aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-01-26 09:49:23 -0800
committerGravatar GitHub <noreply@github.com>2018-01-26 09:49:23 -0800
commit7611eef0208fa1413880a704e622e57bbcfad0d6 (patch)
tree4cfbf626431c11f1f1e5219895800de7114edc03
parent9384314e0cbf6f315d870200fc5abe421deefcab (diff)
parent78021a9a70923f1fdaa65b41271ad0ea70cd7e67 (diff)
Merge branch 'master' into tensorrt
-rw-r--r--RELEASE.md2
-rw-r--r--WORKSPACE8
-rw-r--r--configure.py116
-rw-r--r--tensorflow/BUILD15
-rw-r--r--tensorflow/c/BUILD12
-rw-r--r--tensorflow/c/c_api.cc15
-rw-r--r--tensorflow/c/c_api_test.cc31
-rw-r--r--tensorflow/c/c_test_util.h6
-rw-r--r--tensorflow/c/eager/c_api.cc28
-rw-r--r--tensorflow/c/eager/c_api.h18
-rw-r--r--tensorflow/c/eager/c_api_internal.h10
-rw-r--r--tensorflow/c/eager/c_api_test.cc49
-rw-r--r--tensorflow/c/python_api.h6
-rw-r--r--tensorflow/cc/BUILD1
-rw-r--r--tensorflow/cc/framework/cc_op_gen.h6
-rw-r--r--tensorflow/cc/framework/grad_op_registry.h6
-rw-r--r--tensorflow/cc/framework/gradient_checker.h6
-rw-r--r--tensorflow/cc/framework/gradients.h6
-rw-r--r--tensorflow/cc/framework/ops.h6
-rw-r--r--tensorflow/cc/framework/scope.h6
-rw-r--r--tensorflow/cc/framework/scope_internal.h6
-rw-r--r--tensorflow/cc/framework/testutil.h6
-rw-r--r--tensorflow/cc/framework/while_gradients.h6
-rw-r--r--tensorflow/cc/gradients/grad_testutil.h6
-rw-r--r--tensorflow/cc/ops/const_op.h6
-rw-r--r--tensorflow/cc/ops/standard_ops.h6
-rw-r--r--tensorflow/cc/ops/while_loop.h6
-rw-r--r--tensorflow/cc/profiler/profiler.h6
-rw-r--r--tensorflow/cc/saved_model/constants.h6
-rw-r--r--tensorflow/cc/saved_model/loader.h6
-rw-r--r--tensorflow/cc/saved_model/signature_constants.h6
-rw-r--r--tensorflow/cc/saved_model/tag_constants.h6
-rw-r--r--tensorflow/cc/tools/freeze_saved_model.h6
-rw-r--r--tensorflow/cc/tools/freeze_saved_model_test.cc2
-rw-r--r--tensorflow/cc/training/coordinator.h6
-rw-r--r--tensorflow/cc/training/queue_runner.h6
-rw-r--r--tensorflow/compiler/tests/BUILD5
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py22
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py15
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py6
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/shape_util.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h20
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc36
-rw-r--r--tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc6
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc43
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h6
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h4
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc10
-rw-r--r--tensorflow/compiler/xla/execution_options_util.h6
-rw-r--r--tensorflow/compiler/xla/iterator_util.h6
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc5
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.h6
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h6
-rw-r--r--tensorflow/compiler/xla/literal_util.cc10
-rw-r--r--tensorflow/compiler/xla/map_util.h21
-rw-r--r--tensorflow/compiler/xla/primitive_util.cc73
-rw-r--r--tensorflow/compiler/xla/primitive_util.h62
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py48
-rw-r--r--tensorflow/compiler/xla/service/BUILD32
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc59
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc139
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc46
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc32
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc281
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_fork_join.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matvec.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/shape_partition.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc2
-rw-r--r--tensorflow/compiler/xla/service/dot_decomposer.h6
-rw-r--r--tensorflow/compiler/xla/service/executable.cc2
-rw-r--r--tensorflow/compiler/xla/service/executable.h21
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.cc72
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.h65
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc41
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc43
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_constants.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc43
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc120
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk.h1
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.cc102
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc121
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc206
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile.cc118
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile.h15
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc55
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc42
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc52
-rw-r--r--tensorflow/compiler/xla/service/hlo_profile_printer.cc45
-rw-r--r--tensorflow/compiler/xla/service/hlo_profile_printer.h85
-rw-r--r--tensorflow/compiler/xla/service/hlo_profile_printer_data.proto60
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc19
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h6
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc5
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h6
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc7
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ops.h6
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc3
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.h6
-rw-r--r--tensorflow/compiler/xla/service/service.cc49
-rw-r--r--tensorflow/compiler/xla/service/service.h8
-rw-r--r--tensorflow/compiler/xla/service/source_map_util.cc66
-rw-r--r--tensorflow/compiler/xla/service/source_map_util.h46
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc47
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.h6
-rw-r--r--tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h6
-rw-r--r--tensorflow/compiler/xla/sparse_index_array.h6
-rw-r--r--tensorflow/compiler/xla/statusor_internals.h6
-rw-r--r--tensorflow/compiler/xla/tests/BUILD46
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/check_execution_arity_test.cc9
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc4
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h5
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc36
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc23
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc28
-rw-r--r--tensorflow/compiler/xla/tests/filecheck.h6
-rw-r--r--tensorflow/compiler/xla/tests/half_test.cc257
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc149
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h31
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc3
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc72
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc76
-rw-r--r--tensorflow/compiler/xla/tests/reduce_hlo_test.cc132
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc162
-rw-r--r--tensorflow/compiler/xla/tests/select_and_scatter_test.cc189
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc2
-rw-r--r--tensorflow/compiler/xla/tests/unary_op_test.cc18
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc128
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc2
-rw-r--r--tensorflow/compiler/xla/util.cc15
-rw-r--r--tensorflow/compiler/xla/util.h18
-rw-r--r--tensorflow/compiler/xla/window_util.cc31
-rw-r--r--tensorflow/compiler/xla/window_util.h18
-rw-r--r--tensorflow/compiler/xla/xla.proto4
-rw-r--r--tensorflow/contrib/BUILD2
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/android/asset_manager_filesystem.h6
-rw-r--r--tensorflow/contrib/batching/BUILD47
-rw-r--r--tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h6
-rw-r--r--tensorflow/contrib/batching/basic_batch_scheduler.h6
-rw-r--r--tensorflow/contrib/batching/batch_scheduler.h6
-rw-r--r--tensorflow/contrib/batching/kernels/BUILD34
-rw-r--r--tensorflow/contrib/batching/ops/batch_ops.cc164
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops.py12
-rw-r--r--tensorflow/contrib/batching/shared_batch_scheduler.h6
-rw-r--r--tensorflow/contrib/batching/test_util/fake_clock_env.h6
-rw-r--r--tensorflow/contrib/batching/util/periodic_function.h6
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py384
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py2
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/layers.py38
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py1670
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py (renamed from tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py)92
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/layers_util.py11
-rw-r--r--tensorflow/contrib/boosted_trees/BUILD1
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc27
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/batch_features.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/example.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/macros.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/optional_value.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/random.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h6
-rw-r--r--tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h6
-rw-r--r--tensorflow/contrib/boosted_trees/ops/quantile_ops.cc1
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py100
-rw-r--r--tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py2
-rw-r--r--tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py18
-rw-r--r--tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h6
-rw-r--r--tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h22
-rw-r--r--tensorflow/contrib/boosted_trees/resources/stamped_resource.h6
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc28
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h6
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h6
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt3
-rw-r--r--tensorflow/contrib/cmake/external/snappy.cmake2
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt7
-rw-r--r--tensorflow/contrib/cmake/python_sanity_test.py128
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake2
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake10
-rw-r--r--tensorflow/contrib/coder/kernels/range_coder.h6
-rw-r--r--tensorflow/contrib/coder/kernels/range_coder_ops_util.h6
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py75
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py18
-rw-r--r--tensorflow/contrib/distributions/BUILD29
-rw-r--r--tensorflow/contrib/distributions/__init__.py3
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py94
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py12
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py144
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py52
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py148
-rw-r--r--tensorflow/contrib/distributions/python/ops/autoregressive.py208
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py25
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py282
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture_same_family.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py300
-rw-r--r--tensorflow/contrib/eager/python/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/checkpointable_test.py48
-rw-r--r--tensorflow/contrib/eager/python/datasets.py13
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py14
-rw-r--r--tensorflow/contrib/eager/python/evaluator.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/linear_regression/BUILD10
-rw-r--r--tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py18
-rw-r--r--tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py85
-rw-r--r--tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py8
-rw-r--r--tensorflow/contrib/eager/python/examples/mnist/mnist.py22
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/README.md2
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py7
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/README.md2
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py4
-rw-r--r--tensorflow/contrib/eager/python/examples/spinn/data.py10
-rw-r--r--tensorflow/contrib/eager/python/g3doc/guide.md2
-rw-r--r--tensorflow/contrib/eager/python/network_test.py2
-rw-r--r--tensorflow/contrib/eager/python/saver.py2
-rw-r--r--tensorflow/contrib/estimator/BUILD3
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py160
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py84
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans.py14
-rw-r--r--tensorflow/contrib/ffmpeg/decode_audio_op.cc20
-rw-r--r--tensorflow/contrib/ffmpeg/decode_audio_op_test.py19
-rw-r--r--tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc28
-rw-r--r--tensorflow/contrib/ffmpeg/ffmpeg_lib.h8
-rw-r--r--tensorflow/contrib/ffmpeg/ffmpeg_ops.py7
-rw-r--r--tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4bin0 -> 69357 bytes
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py15
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h6
-rw-r--r--tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc7
-rw-r--r--tensorflow/contrib/gan/BUILD5
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_impl.py33
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_test.py69
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl.py61
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py29
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py31
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py38
-rw-r--r--tensorflow/contrib/gan/python/namedtuples.py33
-rw-r--r--tensorflow/contrib/gan/python/train.py169
-rw-r--r--tensorflow/contrib/gan/python/train_test.py78
-rw-r--r--tensorflow/contrib/gdr/BUILD2
-rw-r--r--tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc2
-rw-r--r--tensorflow/contrib/gdr/gdr_worker.cc13
-rw-r--r--tensorflow/contrib/gdr/gdr_worker.h2
-rw-r--r--tensorflow/contrib/hvx/README.md137
-rw-r--r--tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h6
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py36
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions.py7
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/core_test.py2
-rw-r--r--tensorflow/contrib/layers/__init__.py1
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_test.py150
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py688
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py12
-rw-r--r--tensorflow/contrib/learn/BUILD2
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/__init__.py23
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/synthetic.py66
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py262
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py223
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimators_test.py32
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py102
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export_test.py34
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/gc_test.py49
-rw-r--r--tensorflow/contrib/lite/allocation.h6
-rw-r--r--tensorflow/contrib/lite/arena_planner.h6
-rw-r--r--tensorflow/contrib/lite/build_def.bzl10
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h38
-rw-r--r--tensorflow/contrib/lite/context.h6
-rwxr-xr-xtensorflow/contrib/lite/download_dependencies.sh9
-rw-r--r--tensorflow/contrib/lite/error_reporter.h6
-rw-r--r--tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm31
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.cc55
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.h7
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.md12
-rw-r--r--tensorflow/contrib/lite/graph_info.h6
-rw-r--r--tensorflow/contrib/lite/interpreter.cc1
-rw-r--r--tensorflow/contrib/lite/interpreter.h6
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD28
-rw-r--r--tensorflow/contrib/lite/kernels/activation_functor.h6
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc29
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc7
-rw-r--r--tensorflow/contrib/lite/kernels/gather_test.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD25
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/compatibility.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h54
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h117
-rw-r--r--tensorflow/contrib/lite/kernels/internal/round.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.cc1
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h6
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h22
-rw-r--r--tensorflow/contrib/lite/kernels/op_macros.h6
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc99
-rw-r--r--tensorflow/contrib/lite/kernels/pad_test.cc109
-rw-r--r--tensorflow/contrib/lite/kernels/padding.h6
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/register.h6
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc44
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear_test.cc29
-rw-r--r--tensorflow/contrib/lite/kernels/squeeze_test.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc256
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice_test.cc375
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc36
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h13
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc527
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc1089
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc139
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc102
-rw-r--r--tensorflow/contrib/lite/memory_planner.h6
-rw-r--r--tensorflow/contrib/lite/model.cc61
-rw-r--r--tensorflow/contrib/lite/model.h6
-rw-r--r--tensorflow/contrib/lite/model_test.cc9
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.h6
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor_test.cc8
-rw-r--r--tensorflow/contrib/lite/models/speech_asr_am_model_test.cc127
-rw-r--r--tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc122
-rw-r--r--tensorflow/contrib/lite/models/speech_endpointer_model_test.cc104
-rw-r--r--tensorflow/contrib/lite/models/speech_hotword_model_test.cc114
-rw-r--r--tensorflow/contrib/lite/models/speech_speakerid_model_test.cc121
-rw-r--r--tensorflow/contrib/lite/models/speech_test.cc189
-rw-r--r--tensorflow/contrib/lite/models/speech_tts_model_test.cc116
-rw-r--r--tensorflow/contrib/lite/models/test_utils.h84
-rw-r--r--tensorflow/contrib/lite/models/testdata/g3doc/README.md4
-rw-r--r--tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec202
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h2
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc2
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h6
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.h6
-rw-r--r--tensorflow/contrib/lite/python/lite.py13
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs26
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h470
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.h6
-rw-r--r--tensorflow/contrib/lite/string_util.h6
-rw-r--r--tensorflow/contrib/lite/testing/BUILD1
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py296
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc57
-rw-r--r--tensorflow/contrib/lite/testing/message.h6
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.cc10
-rw-r--r--tensorflow/contrib/lite/testing/parse_testdata.h6
-rw-r--r--tensorflow/contrib/lite/testing/split.h6
-rw-r--r--tensorflow/contrib/lite/testing/test_runner.h6
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h6
-rw-r--r--tensorflow/contrib/lite/testing/tokenize.h6
-rw-r--r--tensorflow/contrib/lite/testing/util.h6
-rw-r--r--tensorflow/contrib/lite/toco/BUILD3
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc12
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.h6
-rw-r--r--tensorflow/contrib/lite/toco/args.h11
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc8
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.h6
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc61
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.h6
-rw-r--r--tensorflow/contrib/lite/toco/format_port.h6
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc149
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc51
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc21
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc199
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc2
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc19
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.h6
-rw-r--r--tensorflow/contrib/lite/toco/model.h71
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc18
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.h7
-rw-r--r--tensorflow/contrib/lite/toco/model_flags.proto37
-rw-r--r--tensorflow/contrib/lite/toco/runtime/common.h6
-rw-r--r--tensorflow/contrib/lite/toco/runtime/types.h6
-rw-r--r--tensorflow/contrib/lite/toco/tensorflow_util.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD10
-rw-r--r--tensorflow/contrib/lite/toco/tflite/builtin_operator.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/custom_operator.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc9
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import_test.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc38
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc32
-rw-r--r--tensorflow/contrib/lite/toco/tflite/simple_operator.h6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types.h6
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.h6
-rw-r--r--tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h6
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.h27
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc7
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.h6
-rw-r--r--tensorflow/contrib/lite/toco/toco_types.h6
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc100
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h33
-rw-r--r--tensorflow/contrib/lite/tools/BUILD27
-rw-r--r--tensorflow/contrib/lite/tools/gen_op_registration.h6
-rw-r--r--tensorflow/contrib/lite/tools/mutable_op_resolver.h6
-rw-r--r--tensorflow/contrib/lite/tools/verifier.cc43
-rw-r--r--tensorflow/contrib/lite/tools/verifier.h31
-rw-r--r--tensorflow/contrib/lite/tools/verifier_test.cc136
-rw-r--r--tensorflow/contrib/lite/version.h6
-rw-r--r--tensorflow/contrib/makefile/Makefile6
-rwxr-xr-xtensorflow/contrib/makefile/build_all_android.sh4
-rwxr-xr-xtensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh4
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py77
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py933
-rw-r--r--tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py86
-rw-r--r--tensorflow/contrib/mpi/BUILD2
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc8
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.h6
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager.h6
-rw-r--r--tensorflow/contrib/ndlstm/python/lstm1d.py12
-rw-r--r--tensorflow/contrib/nearest_neighbor/kernels/heap.h6
-rw-r--r--tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h6
-rw-r--r--tensorflow/contrib/opt/BUILD22
-rw-r--r--tensorflow/contrib/opt/__init__.py5
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer.py115
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py99
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py5
-rw-r--r--tensorflow/contrib/opt/python/training/model_average_optimizer.py308
-rw-r--r--tensorflow/contrib/opt/python/training/model_average_optimizer_test.py198
-rw-r--r--tensorflow/contrib/opt/python/training/nadam_optimizer.py15
-rw-r--r--tensorflow/contrib/periodic_resample/BUILD18
-rw-r--r--tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h11
-rw-r--r--tensorflow/contrib/periodic_resample/ops/array_ops.cc42
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py55
-rw-r--r--tensorflow/contrib/predictor/predictor_factories_test.py4
-rw-r--r--tensorflow/contrib/py2tf/BUILD4
-rw-r--r--tensorflow/contrib/py2tf/api.py158
-rw-r--r--tensorflow/contrib/py2tf/api_test.py140
-rw-r--r--tensorflow/contrib/py2tf/config.py1
-rw-r--r--tensorflow/contrib/py2tf/conversion.py177
-rw-r--r--tensorflow/contrib/py2tf/conversion_test.py19
-rw-r--r--tensorflow/contrib/py2tf/converters/BUILD (renamed from tensorflow/contrib/py2tf/convert/BUILD)44
-rw-r--r--tensorflow/contrib/py2tf/converters/__init__.py (renamed from tensorflow/contrib/py2tf/convert/__init__.py)0
-rw-r--r--tensorflow/contrib/py2tf/converters/break_canonicalization.py (renamed from tensorflow/contrib/py2tf/convert/break_canonicalization.py)26
-rw-r--r--tensorflow/contrib/py2tf/converters/break_canonicalization_test.py (renamed from tensorflow/contrib/py2tf/convert/break_canonicalization_test.py)20
-rw-r--r--tensorflow/contrib/py2tf/converters/builtin_functions.py (renamed from tensorflow/contrib/py2tf/convert/builtin_functions.py)7
-rw-r--r--tensorflow/contrib/py2tf/converters/builtin_functions_test.py (renamed from tensorflow/contrib/py2tf/convert/builtin_functions_test.py)18
-rw-r--r--tensorflow/contrib/py2tf/converters/call_trees.py (renamed from tensorflow/contrib/py2tf/convert/call_trees.py)128
-rw-r--r--tensorflow/contrib/py2tf/converters/call_trees_test.py (renamed from tensorflow/contrib/py2tf/convert/call_trees_test.py)26
-rw-r--r--tensorflow/contrib/py2tf/converters/continue_canonicalization.py (renamed from tensorflow/contrib/py2tf/convert/continue_canonicalization.py)22
-rw-r--r--tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py (renamed from tensorflow/contrib/py2tf/convert/continue_canonicalization_test.py)20
-rw-r--r--tensorflow/contrib/py2tf/converters/control_flow.py (renamed from tensorflow/contrib/py2tf/convert/control_flow.py)96
-rw-r--r--tensorflow/contrib/py2tf/converters/control_flow_test.py (renamed from tensorflow/contrib/py2tf/convert/control_flow_test.py)24
-rw-r--r--tensorflow/contrib/py2tf/converters/converter_test_base.py48
-rw-r--r--tensorflow/contrib/py2tf/converters/decorators.py56
-rw-r--r--tensorflow/contrib/py2tf/converters/for_canonicalization.py (renamed from tensorflow/contrib/py2tf/convert/for_canonicalization.py)28
-rw-r--r--tensorflow/contrib/py2tf/converters/for_canonicalization_test.py (renamed from tensorflow/contrib/py2tf/convert/for_canonicalization_test.py)16
-rw-r--r--tensorflow/contrib/py2tf/converters/logical_expressions.py (renamed from tensorflow/contrib/py2tf/convert/logical_expressions.py)0
-rw-r--r--tensorflow/contrib/py2tf/converters/logical_expressions_test.py (renamed from tensorflow/contrib/py2tf/convert/logical_expressions_test.py)10
-rw-r--r--tensorflow/contrib/py2tf/converters/print_functions.py (renamed from tensorflow/contrib/py2tf/convert/print_functions.py)0
-rw-r--r--tensorflow/contrib/py2tf/converters/print_functions_test.py (renamed from tensorflow/contrib/py2tf/convert/print_functions_test.py)18
-rw-r--r--tensorflow/contrib/py2tf/converters/side_effect_guards.py (renamed from tensorflow/contrib/py2tf/convert/side_effect_guards.py)28
-rw-r--r--tensorflow/contrib/py2tf/converters/side_effect_guards_test.py (renamed from tensorflow/contrib/py2tf/convert/side_effect_guards_test.py)18
-rw-r--r--tensorflow/contrib/py2tf/naming.py9
-rw-r--r--tensorflow/contrib/py2tf/naming_test.py12
-rw-r--r--tensorflow/contrib/py2tf/pyct/BUILD2
-rw-r--r--tensorflow/contrib/py2tf/pyct/context.py42
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py18
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py70
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates.py55
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates_test.py36
-rw-r--r--tensorflow/contrib/py2tf/pyct/transformer.py58
-rw-r--r--tensorflow/contrib/quantize/__init__.py2
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph.py64
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py56
-rw-r--r--tensorflow/contrib/receptive_field/python/util/graph_compute_order.py2
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h6
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc11
-rw-r--r--tensorflow/contrib/resampler/kernels/resampler_ops.h7
-rw-r--r--tensorflow/contrib/rnn/kernels/blas_gemm.h6
-rw-r--r--tensorflow/contrib/rnn/kernels/gru_ops.h6
-rw-r--r--tensorflow/contrib/rnn/kernels/lstm_ops.h6
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py6
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py1000
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py814
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h6
-rw-r--r--tensorflow/contrib/seq2seq/kernels/beam_search_ops.h6
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py121
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py116
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/helper.py36
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.h6
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.py8
-rw-r--r--tensorflow/contrib/session_bundle/constants.py3
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle.h6
-rw-r--r--tensorflow/contrib/session_bundle/signature.h6
-rw-r--r--tensorflow/contrib/session_bundle/test_util.h6
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation_test.py2
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py51
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/util_test.py35
-rw-r--r--tensorflow/contrib/solvers/python/ops/linear_equations.py52
-rw-r--r--tensorflow/contrib/solvers/python/ops/util.py17
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py34
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py48
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/data_spec.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/tree_utils.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h10
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc2
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_target.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/params.h7
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h8
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h6
-rw-r--r--tensorflow/contrib/tensorrt/BUILD38
-rw-r--r--tensorflow/contrib/tensorrt/tensorrt_test.cc159
-rw-r--r--tensorflow/contrib/tpu/profiler/BUILD11
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc15
-rw-r--r--tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc7
-rw-r--r--tensorflow/contrib/tpu/profiler/dump_tpu_profile.h6
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py3
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/setup.py9
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto5
-rw-r--r--tensorflow/contrib/tpu/profiler/tpu_profiler.proto13
-rw-r--r--tensorflow/contrib/tpu/profiler/trace_events_to_json.h6
-rw-r--r--tensorflow/contrib/tpu/profiler/version.h21
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py58
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py464
-rw-r--r--tensorflow/contrib/tpu/tpu_estimator.md241
-rw-r--r--tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py30
-rw-r--r--tensorflow/contrib/training/python/training/hparam.py2
-rw-r--r--tensorflow/contrib/training/python/training/hparam_test.py2
-rw-r--r--tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h6
-rw-r--r--tensorflow/contrib/verbs/BUILD6
-rw-r--r--tensorflow/contrib/verbs/README.md174
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_client.h6
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service.cc12
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service.h6
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.h6
-rw-r--r--tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md87
-rw-r--r--tensorflow/contrib/verbs/rdma.cc1374
-rw-r--r--tensorflow/contrib/verbs/rdma.h519
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.cc228
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.h7
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc114
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.h6
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.cc1
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.h6
-rw-r--r--tensorflow/contrib/verbs/verbs_service.proto6
-rw-r--r--tensorflow/contrib/verbs/verbs_with_0_copies.pngbin0 -> 62862 bytes
-rw-r--r--tensorflow/contrib/verbs/verbs_with_0_copies.xml1
-rw-r--r--tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.jpgbin0 -> 88799 bytes
-rw-r--r--tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.xml1
-rw-r--r--tensorflow/core/BUILD24
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Batch.pbtxt42
-rw-r--r--tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IteratorGetNextSync.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorGetItem.pbtxt11
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListElementShape.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListReserve.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorSetItem.pbtxt11
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Unbatch.pbtxt24
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnbatchGrad.pbtxt20
-rw-r--r--tensorflow/core/api_def/update_api_def.cc8
-rw-r--r--tensorflow/core/api_def/update_api_def.h8
-rw-r--r--tensorflow/core/api_def/update_api_def_test.cc32
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc2
-rw-r--r--tensorflow/core/common_runtime/function_testlib.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc2
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_utils.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h6
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.h6
-rw-r--r--tensorflow/core/common_runtime/memory_types.h6
-rw-r--r--tensorflow/core/common_runtime/pending_counts.h6
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h6
-rw-r--r--tensorflow/core/common_runtime/profile_handler.h6
-rw-r--r--tensorflow/core/common_runtime/renamed_device.h6
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.h6
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.h6
-rw-r--r--tensorflow/core/common_runtime/stats_publisher_interface.h6
-rw-r--r--tensorflow/core/distributed_runtime/BUILD45
-rw-r--r--tensorflow/core/distributed_runtime/cluster_function_library_runtime.h6
-rw-r--r--tensorflow/core/distributed_runtime/local_master.h6
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc4
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.h6
-rw-r--r--tensorflow/core/distributed_runtime/partial_run_mgr.h6
-rw-r--r--tensorflow/core/distributed_runtime/recent_request_ids.cc57
-rw-r--r--tensorflow/core/distributed_runtime/recent_request_ids.h72
-rw-r--r--tensorflow/core/distributed_runtime/recent_request_ids_test.cc96
-rw-r--r--tensorflow/core/distributed_runtime/request_id.cc30
-rw-r--r--tensorflow/core/distributed_runtime/request_id.h31
-rw-r--r--tensorflow/core/distributed_runtime/request_id_test.cc29
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/async_service_interface.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_call.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_state.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_util.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc30
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h13
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc2
-rw-r--r--tensorflow/core/distributed_runtime/server_lib.h6
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.cc82
-rw-r--r--tensorflow/core/distributed_runtime/session_mgr.h15
-rw-r--r--tensorflow/core/distributed_runtime/worker.h6
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.h6
-rw-r--r--tensorflow/core/example/example_parser_configuration.h6
-rw-r--r--tensorflow/core/framework/common_shape_fns.h6
-rw-r--r--tensorflow/core/framework/dataset.h75
-rw-r--r--tensorflow/core/framework/op_gen_lib.cc29
-rw-r--r--tensorflow/core/framework/op_kernel.cc43
-rw-r--r--tensorflow/core/framework/op_kernel.h68
-rw-r--r--tensorflow/core/framework/register_types.h2
-rw-r--r--tensorflow/core/framework/shape_inference.h6
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.h6
-rw-r--r--tensorflow/core/framework/tensor.cc4
-rw-r--r--tensorflow/core/framework/tensor.h6
-rw-r--r--tensorflow/core/framework/types.h7
-rw-r--r--tensorflow/core/framework/variant_op_registry.cc24
-rw-r--r--tensorflow/core/framework/variant_op_registry.h44
-rw-r--r--tensorflow/core/graph/costmodel.cc53
-rw-r--r--tensorflow/core/graph/costmodel.h4
-rw-r--r--tensorflow/core/graph/gradients.h6
-rw-r--r--tensorflow/core/grappler/costs/cost_estimator.h3
-rw-r--r--tensorflow/core/grappler/costs/measuring_cost_estimator.cc4
-rw-r--r--tensorflow/core/grappler/costs/op_context.h6
-rw-r--r--tensorflow/core/grappler/costs/op_performance_data.proto7
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.h10
-rw-r--r--tensorflow/core/grappler/grappler_item.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc72
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h8
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc46
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc87
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h1
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/dependency_optimizer.cc6
-rw-r--r--tensorflow/core/grappler/optimizers/dependency_optimizer.h6
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc36
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer_test.cc43
-rw-r--r--tensorflow/core/grappler/optimizers/static_schedule.h6
-rw-r--r--tensorflow/core/grappler/utils/frame.h6
-rw-r--r--tensorflow/core/grappler/utils/scc.h6
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.h6
-rw-r--r--tensorflow/core/kernels/BUILD30
-rw-r--r--tensorflow/core/kernels/adjust_hsv_gpu.cu.h6
-rw-r--r--tensorflow/core/kernels/batch_kernels.cc (renamed from tensorflow/contrib/batching/kernels/batch_kernels.cc)6
-rw-r--r--tensorflow/core/kernels/batch_util.h6
-rw-r--r--tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h7
-rw-r--r--tensorflow/core/kernels/batching_util/basic_batch_scheduler.h6
-rw-r--r--tensorflow/core/kernels/batching_util/batch_scheduler.h6
-rw-r--r--tensorflow/core/kernels/batching_util/fake_clock_env.h6
-rw-r--r--tensorflow/core/kernels/batching_util/periodic_function.h7
-rw-r--r--tensorflow/core/kernels/batching_util/shared_batch_scheduler.h6
-rw-r--r--tensorflow/core/kernels/bias_op_gpu.cu.cc18
-rw-r--r--tensorflow/core/kernels/bitcast_op.h6
-rw-r--r--tensorflow/core/kernels/captured_function.h6
-rw-r--r--tensorflow/core/kernels/cast_op_impl.h6
-rw-r--r--tensorflow/core/kernels/compare_and_bitpack_op.h6
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.h26
-rw-r--r--tensorflow/core/kernels/cuda_device_array.h6
-rw-r--r--tensorflow/core/kernels/cuda_device_array_gpu.h6
-rw-r--r--tensorflow/core/kernels/cuda_solvers.h2
-rw-r--r--tensorflow/core/kernels/cwise_op_pow.cc7
-rw-r--r--tensorflow/core/kernels/cwise_ops.h35
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.cc5
-rw-r--r--tensorflow/core/kernels/data/captured_function.h6
-rw-r--r--tensorflow/core/kernels/data/dataset.h20
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h6
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc39
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/sql/driver_manager.h6
-rw-r--r--tensorflow/core/kernels/data/sql/query_connection.h6
-rw-r--r--tensorflow/core/kernels/data/sql/sqlite_query_connection.h6
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator.h6
-rw-r--r--tensorflow/core/kernels/data/window_dataset.h6
-rw-r--r--tensorflow/core/kernels/dataset.h6
-rw-r--r--tensorflow/core/kernels/decode_bmp_op.cc19
-rw-r--r--tensorflow/core/kernels/decode_image_op.cc39
-rw-r--r--tensorflow/core/kernels/deep_conv2d.h6
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op.h6
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc13
-rw-r--r--tensorflow/core/kernels/determinant_op.h6
-rw-r--r--tensorflow/core/kernels/eigen_activations.h6
-rw-r--r--tensorflow/core/kernels/eigen_attention.h6
-rw-r--r--tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h6
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h4
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h6
-rw-r--r--tensorflow/core/kernels/eigen_pooling.h6
-rw-r--r--tensorflow/core/kernels/eigen_softmax.h6
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions.h6
-rw-r--r--tensorflow/core/kernels/eigen_volume_patch.h6
-rw-r--r--tensorflow/core/kernels/eye_functor.h6
-rw-r--r--tensorflow/core/kernels/fake_quant_ops_functor.h6
-rw-r--r--tensorflow/core/kernels/gather_functor_gpu.cu.h6
-rw-r--r--tensorflow/core/kernels/gpu_utils.h6
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.h6
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h6
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h6
-rw-r--r--tensorflow/core/kernels/i_remote_fused_graph_executor.h6
-rw-r--r--tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h6
-rw-r--r--tensorflow/core/kernels/list_kernels.cc175
-rw-r--r--tensorflow/core/kernels/list_kernels.h36
-rw-r--r--tensorflow/core/kernels/meta_support.h6
-rw-r--r--tensorflow/core/kernels/mfcc.h6
-rw-r--r--tensorflow/core/kernels/mfcc_dct.h6
-rw-r--r--tensorflow/core/kernels/mfcc_mel_filterbank.h6
-rw-r--r--tensorflow/core/kernels/mirror_pad_op_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/mkl_aggregate_ops.cc186
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_lrn_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_softmax_op.cc33
-rw-r--r--tensorflow/core/kernels/neon/BUILD2
-rw-r--r--tensorflow/core/kernels/neon/depthwiseconv_float.h6
-rw-r--r--tensorflow/core/kernels/neon/types.h6
-rw-r--r--tensorflow/core/kernels/pack_op.cc1
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.cc2
-rw-r--r--tensorflow/core/kernels/population_count_op.h6
-rw-r--r--tensorflow/core/kernels/quantization_utils.h6
-rw-r--r--tensorflow/core/kernels/reference_gemm.h6
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h6
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.h6
-rw-r--r--tensorflow/core/kernels/reshape_util.h6
-rw-r--r--tensorflow/core/kernels/scatter_nd_op_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc21
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h6
-rw-r--r--tensorflow/core/kernels/slice_op_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/spectrogram.h6
-rw-r--r--tensorflow/core/kernels/spectrogram_test_utils.cc14
-rw-r--r--tensorflow/core/kernels/spectrogram_test_utils.h6
-rw-r--r--tensorflow/core/kernels/summary_kernels.cc43
-rw-r--r--tensorflow/core/kernels/svd_op_gpu.cu.cc4
-rw-r--r--tensorflow/core/kernels/tile_ops_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/tile_ops_gpu_impl.h6
-rw-r--r--tensorflow/core/kernels/transpose_functor_cpu.cc12
-rw-r--r--tensorflow/core/kernels/transpose_functor_gpu.cu.cc21
-rw-r--r--tensorflow/core/kernels/winograd_transform.h6
-rw-r--r--tensorflow/core/kernels/xent_op.cc10
-rw-r--r--tensorflow/core/kernels/xsmm_conv2d.h6
-rw-r--r--tensorflow/core/lib/core/bitmap.h6
-rw-r--r--tensorflow/core/lib/gif/gif_io.cc16
-rw-r--r--tensorflow/core/lib/gif/gif_io.h3
-rw-r--r--tensorflow/core/lib/gtl/compactptrset.h6
-rw-r--r--tensorflow/core/lib/gtl/flatmap.h6
-rw-r--r--tensorflow/core/lib/gtl/flatrep.h6
-rw-r--r--tensorflow/core/lib/gtl/flatset.h6
-rw-r--r--tensorflow/core/lib/io/buffered_inputstream.h2
-rw-r--r--tensorflow/core/lib/io/compression.h6
-rw-r--r--tensorflow/core/lib/io/inputstream_interface.h2
-rw-r--r--tensorflow/core/lib/io/random_inputstream.cc37
-rw-r--r--tensorflow/core/lib/io/random_inputstream.h2
-rw-r--r--tensorflow/core/lib/io/snappy/snappy_outputbuffer.h6
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.h6
-rw-r--r--tensorflow/core/lib/monitoring/collected_metrics.h6
-rw-r--r--tensorflow/core/lib/monitoring/collection_registry.h6
-rw-r--r--tensorflow/core/lib/monitoring/counter.h6
-rw-r--r--tensorflow/core/lib/monitoring/gauge.h6
-rw-r--r--tensorflow/core/lib/monitoring/metric_def.h6
-rw-r--r--tensorflow/core/lib/monitoring/mobile_counter.h6
-rw-r--r--tensorflow/core/lib/monitoring/mobile_gauge.h6
-rw-r--r--tensorflow/core/lib/monitoring/mobile_sampler.h6
-rw-r--r--tensorflow/core/lib/monitoring/sampler.h6
-rw-r--r--tensorflow/core/lib/strings/proto_text_util.h6
-rw-r--r--tensorflow/core/ops/batch_ops.cc84
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt736
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc55
-rw-r--r--tensorflow/core/ops/dataset_ops.cc68
-rw-r--r--tensorflow/core/ops/image_ops.cc42
-rw-r--r--tensorflow/core/ops/io_ops.cc16
-rw-r--r--tensorflow/core/ops/list_ops.cc76
-rw-r--r--tensorflow/core/ops/ops.pbtxt322
-rw-r--r--tensorflow/core/ops/training_ops.cc42
-rw-r--r--tensorflow/core/platform/cloud/BUILD4
-rw-r--r--tensorflow/core/platform/cloud/gcs_dns_cache.h6
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc49
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h14
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc424
-rw-r--r--tensorflow/core/platform/cloud/oauth_client.h6
-rw-r--r--tensorflow/core/platform/cloud/retrying_utils.h6
-rw-r--r--tensorflow/core/platform/cloud/time_util.h6
-rw-r--r--tensorflow/core/platform/cuda_libdevice_path.h6
-rw-r--r--tensorflow/core/platform/cupti_wrapper.h6
-rw-r--r--tensorflow/core/platform/default/build_config/BUILD4
-rw-r--r--tensorflow/core/platform/default/gpu/cupti_wrapper.h6
-rw-r--r--tensorflow/core/platform/demangle.h6
-rw-r--r--tensorflow/core/platform/file_statistics.h6
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.h6
-rw-r--r--tensorflow/core/platform/profile_utils/cpu_utils.h4
-rw-r--r--tensorflow/core/platform/s3/BUILD12
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc143
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.h9
-rw-r--r--tensorflow/core/platform/s3/s3_file_system_test.cc2
-rw-r--r--tensorflow/core/platform/stacktrace_handler.h6
-rw-r--r--tensorflow/core/profiler/BUILD2
-rw-r--r--tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h6
-rw-r--r--tensorflow/core/profiler/internal/advisor/checker.h6
-rw-r--r--tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h6
-rw-r--r--tensorflow/core/profiler/internal/advisor/internal_checker_runner.h6
-rw-r--r--tensorflow/core/profiler/internal/advisor/operation_checker.h6
-rw-r--r--tensorflow/core/profiler/internal/advisor/tfprof_advisor.h6
-rw-r--r--tensorflow/core/profiler/internal/print_model_analysis.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_code.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_constants.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_graph.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_node.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_node_show.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_op.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_scope.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_show.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_show_multi.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_stats.h10
-rw-r--r--tensorflow/core/profiler/internal/tfprof_tensor.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_timeline.h7
-rw-r--r--tensorflow/core/profiler/internal/tfprof_utils.cc3
-rw-r--r--tensorflow/core/profiler/internal/tfprof_utils.h6
-rw-r--r--tensorflow/core/profiler/tfprof_options.h6
-rw-r--r--tensorflow/core/protobuf/worker.proto15
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/core/util/command_line_flags.h6
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_entry.h51
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_search.h17
-rw-r--r--tensorflow/core/util/ctc/ctc_decoder.h3
-rw-r--r--tensorflow/core/util/cuda_device_functions.h499
-rw-r--r--tensorflow/core/util/cuda_kernel_helper.h857
-rw-r--r--tensorflow/core/util/cuda_kernel_helper_test.cu.cc60
-rw-r--r--tensorflow/core/util/cuda_launch_config.h284
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.h6
-rw-r--r--tensorflow/core/util/example_proto_helper.h6
-rw-r--r--tensorflow/core/util/matmul_autotune.h6
-rw-r--r--tensorflow/core/util/strided_slice_op.h6
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.signal.md6
-rw-r--r--tensorflow/docs_src/api_guides/python/python_io.md10
-rw-r--r--tensorflow/docs_src/install/install_mac.md4
-rw-r--r--tensorflow/docs_src/programmers_guide/saved_model.md17
-rw-r--r--tensorflow/examples/android/BUILD2
-rw-r--r--tensorflow/examples/android/download-models.gradle2
-rw-r--r--tensorflow/examples/android/jni/object_tracking/config.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/flow_cache.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/frame_pair.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/geom.h6
-rwxr-xr-xtensorflow/examples/android/jni/object_tracking/gl_utils.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/image-inl.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/image.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/image_data.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/image_utils.h6
-rwxr-xr-xtensorflow/examples/android/jni/object_tracking/integral_image.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/jni_utils.h4
-rw-r--r--tensorflow/examples/android/jni/object_tracking/keypoint.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/keypoint_detector.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/logging.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/object_detector.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/object_model.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/object_tracker.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/optical_flow.h6
-rwxr-xr-xtensorflow/examples/android/jni/object_tracking/sprite.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/time_log.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/tracked_object.h6
-rw-r--r--tensorflow/examples/android/jni/object_tracking/utils.h6
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java7
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_reader.py49
-rw-r--r--tensorflow/examples/label_image/label_image.py41
-rw-r--r--tensorflow/examples/speech_commands/accuracy_utils.h6
-rw-r--r--tensorflow/examples/speech_commands/recognize_commands.h6
-rw-r--r--tensorflow/examples/tutorials/word2vec/word2vec_basic.py161
-rw-r--r--tensorflow/examples/udacity/Dockerfile2
-rw-r--r--tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h6
-rw-r--r--tensorflow/go/op/wrappers.go1282
-rw-r--r--tensorflow/python/BUILD8
-rw-r--r--tensorflow/python/client/session.py185
-rw-r--r--tensorflow/python/client/session_test.py310
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD2
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py109
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py44
-rw-r--r--tensorflow/python/debug/lib/debug_gradients_test.py2
-rw-r--r--tensorflow/python/eager/BUILD19
-rw-r--r--tensorflow/python/eager/benchmarks_test.py36
-rw-r--r--tensorflow/python/eager/context.py18
-rw-r--r--tensorflow/python/eager/core_test.py18
-rw-r--r--tensorflow/python/eager/function.py2
-rw-r--r--tensorflow/python/eager/function_test.py5
-rw-r--r--tensorflow/python/eager/ops_test.py10
-rw-r--r--tensorflow/python/eager/python_eager_op_gen.h6
-rw-r--r--tensorflow/python/eager/pywrap_tfe.h22
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc241
-rw-r--r--tensorflow/python/eager/pywrap_tfe_test.py109
-rw-r--r--tensorflow/python/eager/tensor_test.py13
-rw-r--r--tensorflow/python/estimator/BUILD7
-rw-r--r--tensorflow/python/estimator/canned/dnn.py47
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py51
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py2
-rw-r--r--tensorflow/python/estimator/canned/head.py249
-rw-r--r--tensorflow/python/estimator/canned/head_test.py437
-rw-r--r--tensorflow/python/estimator/canned/linear.py51
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py2
-rw-r--r--tensorflow/python/estimator/estimator.py44
-rw-r--r--tensorflow/python/estimator/estimator_test.py28
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions.py101
-rw-r--r--tensorflow/python/estimator/run_config.py18
-rw-r--r--tensorflow/python/estimator/warm_starting_util.py6
-rw-r--r--tensorflow/python/feature_column/BUILD3
-rw-r--r--tensorflow/python/feature_column/feature_column.py10
-rw-r--r--tensorflow/python/feature_column/feature_column_lib.py1
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py59
-rw-r--r--tensorflow/python/framework/constant_op.py3
-rw-r--r--tensorflow/python/framework/cpp_shape_inference.h6
-rw-r--r--tensorflow/python/framework/device.py2
-rw-r--r--tensorflow/python/framework/dtypes.py27
-rw-r--r--tensorflow/python/framework/errors_impl.py45
-rw-r--r--tensorflow/python/framework/function.py4
-rw-r--r--tensorflow/python/framework/function_test.py17
-rw-r--r--tensorflow/python/framework/graph_io.py2
-rw-r--r--tensorflow/python/framework/graph_util_impl.py6
-rw-r--r--tensorflow/python/framework/importer.py2
-rw-r--r--tensorflow/python/framework/load_library.py3
-rw-r--r--tensorflow/python/framework/ops.py214
-rw-r--r--tensorflow/python/framework/python_op_gen_internal.h6
-rw-r--r--tensorflow/python/framework/random_seed.py3
-rw-r--r--tensorflow/python/framework/sparse_tensor.py4
-rw-r--r--tensorflow/python/framework/tensor_shape.py3
-rw-r--r--tensorflow/python/framework/tensor_util.py3
-rw-r--r--tensorflow/python/framework/test_util.py117
-rw-r--r--tensorflow/python/framework/test_util_test.py34
-rw-r--r--tensorflow/python/framework/versions.py9
-rw-r--r--tensorflow/python/grappler/cost_analyzer_tool.py41
-rw-r--r--tensorflow/python/grappler/layout_optimizer_test.py68
-rw-r--r--tensorflow/python/grappler/tf_optimizer.i5
-rwxr-xr-xtensorflow/python/keras/BUILD32
-rw-r--r--tensorflow/python/keras/_impl/keras/__init__.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/activations.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/__init__.py5
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/densenet.py346
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/densenet_test.py101
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py156
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py19
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/inception_v3.py21
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/mobilenet.py70
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/nasnet.py783
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/nasnet_test.py76
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/resnet50.py51
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/vgg16.py60
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/vgg19.py73
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/xception.py90
-rw-r--r--tensorflow/python/keras/_impl/keras/backend.py289
-rw-r--r--tensorflow/python/keras/_impl/keras/backend_test.py12
-rw-r--r--tensorflow/python/keras/_impl/keras/callbacks.py132
-rw-r--r--tensorflow/python/keras/_impl/keras/constraints.py22
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/boston_housing.py8
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/cifar.py3
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/cifar10.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/cifar100.py4
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py7
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/imdb.py60
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/mnist.py8
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/reuters.py57
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/topology.py61
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training.py588
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_test.py87
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/advanced_activations.py68
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py6
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional.py141
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py44
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional_test.py66
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/embeddings.py26
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/local.py110
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/merge.py110
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/noise.py49
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent.py619
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent_test.py99
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/wrappers.py58
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/wrappers_test.py125
-rw-r--r--tensorflow/python/keras/_impl/keras/losses.py20
-rw-r--r--tensorflow/python/keras/_impl/keras/metrics.py9
-rw-r--r--tensorflow/python/keras/_impl/keras/models.py4
-rw-r--r--tensorflow/python/keras/_impl/keras/models_test.py29
-rw-r--r--tensorflow/python/keras/_impl/keras/optimizers.py125
-rw-r--r--tensorflow/python/keras/_impl/keras/optimizers_test.py1
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/image.py164
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/sequence.py25
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/text.py51
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/text_test.py16
-rw-r--r--tensorflow/python/keras/_impl/keras/regularizers.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/data_utils.py209
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/generic_utils.py13
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/io_utils.py16
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/layer_utils.py33
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/np_utils.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/training_utils.py3
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/vis_utils.py33
-rw-r--r--tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py71
-rw-r--r--tensorflow/python/keras/applications/__init__.py7
-rw-r--r--tensorflow/python/keras/applications/densenet/__init__.py29
-rw-r--r--tensorflow/python/keras/applications/nasnet/__init__.py28
-rw-r--r--tensorflow/python/keras/layers/__init__.py3
-rw-r--r--tensorflow/python/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py108
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py14
-rw-r--r--tensorflow/python/kernel_tests/conv1d_test.py44
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py27
-rw-r--r--tensorflow/python/kernel_tests/diag_op_test.py225
-rw-r--r--tensorflow/python/kernel_tests/distributions/categorical_test.py4
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py37
-rw-r--r--tensorflow/python/kernel_tests/map_stage_op_test.py105
-rw-r--r--tensorflow/python/kernel_tests/metrics_test.py29
-rw-r--r--tensorflow/python/kernel_tests/pooling_ops_test.py72
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py23
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py20
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py5
-rw-r--r--tensorflow/python/kernel_tests/scalar_test.py4
-rw-r--r--tensorflow/python/kernel_tests/sparse_slice_op_test.py102
-rw-r--r--tensorflow/python/kernel_tests/stage_op_test.py34
-rw-r--r--tensorflow/python/kernel_tests/tensordot_op_test.py54
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py42
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py10
-rw-r--r--tensorflow/python/layers/base.py139
-rw-r--r--tensorflow/python/layers/convolutional.py452
-rw-r--r--tensorflow/python/layers/convolutional_test.py162
-rw-r--r--tensorflow/python/layers/layers.py4
-rw-r--r--tensorflow/python/layers/maxout.py34
-rw-r--r--tensorflow/python/layers/network.py7
-rw-r--r--tensorflow/python/layers/pooling.py18
-rw-r--r--tensorflow/python/layers/pooling_test.py24
-rw-r--r--tensorflow/python/layers/utils.py2
-rw-r--r--tensorflow/python/lib/core/bfloat16_test.py19
-rw-r--r--tensorflow/python/lib/core/ndarray_tensor.h6
-rw-r--r--tensorflow/python/lib/core/py_func.cc19
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.h6
-rw-r--r--tensorflow/python/lib/core/safe_ptr.h6
-rw-r--r--tensorflow/python/ops/array_grad.py116
-rw-r--r--tensorflow/python/ops/array_ops.py52
-rw-r--r--tensorflow/python/ops/candidate_sampling_ops.py7
-rw-r--r--tensorflow/python/ops/check_ops.py23
-rw-r--r--tensorflow/python/ops/clip_ops.py8
-rw-r--r--tensorflow/python/ops/confusion_matrix.py2
-rw-r--r--tensorflow/python/ops/control_flow_ops.py488
-rw-r--r--tensorflow/python/ops/ctc_ops.py4
-rw-r--r--tensorflow/python/ops/data_flow_ops.py399
-rw-r--r--tensorflow/python/ops/distributions/bernoulli.py2
-rw-r--r--tensorflow/python/ops/distributions/beta.py2
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py2
-rw-r--r--tensorflow/python/ops/distributions/categorical.py6
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py2
-rw-r--r--tensorflow/python/ops/distributions/dirichlet_multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/distribution.py9
-rw-r--r--tensorflow/python/ops/distributions/exponential.py2
-rw-r--r--tensorflow/python/ops/distributions/gamma.py2
-rw-r--r--tensorflow/python/ops/distributions/identity_bijector.py2
-rw-r--r--tensorflow/python/ops/distributions/kullback_leibler.py3
-rw-r--r--tensorflow/python/ops/distributions/laplace.py2
-rw-r--r--tensorflow/python/ops/distributions/multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/normal.py2
-rw-r--r--tensorflow/python/ops/distributions/student_t.py2
-rw-r--r--tensorflow/python/ops/distributions/uniform.py2
-rw-r--r--tensorflow/python/ops/embedding_ops.py3
-rw-r--r--tensorflow/python/ops/functional_ops.py7
-rw-r--r--tensorflow/python/ops/gradient_checker.py3
-rw-r--r--tensorflow/python/ops/gradients_impl.py95
-rw-r--r--tensorflow/python/ops/histogram_ops.py72
-rw-r--r--tensorflow/python/ops/histogram_ops_test.py57
-rw-r--r--tensorflow/python/ops/image_ops.py4
-rw-r--r--tensorflow/python/ops/image_ops_impl.py424
-rw-r--r--tensorflow/python/ops/image_ops_test.py87
-rw-r--r--tensorflow/python/ops/init_ops.py22
-rw-r--r--tensorflow/python/ops/io_ops.py8
-rw-r--r--tensorflow/python/ops/linalg_grad.py59
-rw-r--r--tensorflow/python/ops/linalg_ops.py8
-rw-r--r--tensorflow/python/ops/list_ops.py43
-rw-r--r--tensorflow/python/ops/logging_ops.py2
-rw-r--r--tensorflow/python/ops/lookup_ops.py3
-rw-r--r--tensorflow/python/ops/math_ops.py65
-rw-r--r--tensorflow/python/ops/metrics_impl.py616
-rw-r--r--tensorflow/python/ops/nn_batchnorm_test.py15
-rw-r--r--tensorflow/python/ops/nn_grad.py300
-rw-r--r--tensorflow/python/ops/nn_grad_test.py13
-rw-r--r--tensorflow/python/ops/nn_impl.py60
-rw-r--r--tensorflow/python/ops/nn_ops.py34
-rw-r--r--tensorflow/python/ops/nn_test.py29
-rw-r--r--tensorflow/python/ops/numerics.py3
-rw-r--r--tensorflow/python/ops/parsing_ops.py9
-rw-r--r--tensorflow/python/ops/partitioned_variables.py5
-rw-r--r--tensorflow/python/ops/random_ops.py9
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py38
-rw-r--r--tensorflow/python/ops/rnn.py33
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py15
-rw-r--r--tensorflow/python/ops/script_ops.py2
-rw-r--r--tensorflow/python/ops/session_ops.py4
-rw-r--r--tensorflow/python/ops/sets_impl.py5
-rw-r--r--tensorflow/python/ops/sparse_ops.py27
-rw-r--r--tensorflow/python/ops/special_math_ops.py119
-rw-r--r--tensorflow/python/ops/special_math_ops_test.py74
-rw-r--r--tensorflow/python/ops/spectral_ops.py8
-rw-r--r--tensorflow/python/ops/state_ops.py7
-rw-r--r--tensorflow/python/ops/string_ops.py3
-rw-r--r--tensorflow/python/ops/summary_ops.py2
-rw-r--r--tensorflow/python/ops/template.py2
-rw-r--r--tensorflow/python/ops/tensor_array_ops.py2
-rw-r--r--tensorflow/python/ops/variable_scope.py192
-rw-r--r--tensorflow/python/ops/variables.py80
-rw-r--r--tensorflow/python/platform/benchmark.py2
-rw-r--r--tensorflow/python/platform/gfile.py3
-rw-r--r--tensorflow/python/platform/sysconfig.py5
-rw-r--r--tensorflow/python/platform/test.py5
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py10
-rw-r--r--tensorflow/python/pywrap_tfe.i12
-rw-r--r--tensorflow/python/summary/summary.py5
-rw-r--r--tensorflow/python/tools/freeze_graph.py22
-rw-r--r--tensorflow/python/tools/freeze_graph_test.py3
-rw-r--r--tensorflow/python/tools/inspect_checkpoint.py14
-rw-r--r--tensorflow/python/tools/saved_model_cli.py108
-rw-r--r--tensorflow/python/tools/saved_model_cli_test.py75
-rw-r--r--tensorflow/python/training/adam.py62
-rw-r--r--tensorflow/python/training/adam_test.py7
-rw-r--r--tensorflow/python/training/checkpoint_utils.py5
-rw-r--r--tensorflow/python/training/coordinator_test.py70
-rw-r--r--tensorflow/python/training/moving_averages.py4
-rw-r--r--tensorflow/python/training/optimizer.py29
-rw-r--r--tensorflow/python/training/saver.py11
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer_test.py5
-rw-r--r--tensorflow/python/util/compat.py18
-rw-r--r--tensorflow/python/util/compat_internal.py34
-rw-r--r--tensorflow/python/util/kernel_registry.h6
-rw-r--r--tensorflow/python/util/nest.py87
-rw-r--r--tensorflow/python/util/nest_test.py30
-rw-r--r--tensorflow/python/util/tf_inspect.py11
-rw-r--r--tensorflow/python/util/util.h6
-rw-r--r--tensorflow/stream_executor/cuda/cuda_diagnostics.cc2
-rw-r--r--tensorflow/stream_executor/dso_loader.cc8
-rw-r--r--tensorflow/tensorflow.bzl23
-rw-r--r--tensorflow/tools/api/generator/BUILD29
-rw-r--r--tensorflow/tools/api/generator/create_python_api.py33
-rw-r--r--tensorflow/tools/api/golden/tensorflow.compat.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt18
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt23
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt1
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt1
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt186
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt186
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt183
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt144
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.pbtxt2
-rw-r--r--tensorflow/tools/benchmark/BUILD3
-rw-r--r--tensorflow/tools/benchmark/README.md1
-rwxr-xr-xtensorflow/tools/ci_build/builds/libtensorflow.sh4
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh47
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_cc_core.sh3
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py2_core.sh3
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh3
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_core.sh3
-rwxr-xr-xtensorflow/tools/ci_build/linux/gpu/run_cc_core.sh3
-rwxr-xr-xtensorflow/tools/ci_build/linux/gpu/run_py3_core.sh3
-rwxr-xr-xtensorflow/tools/ci_build/osx/cpu/run_contrib.sh3
-rwxr-xr-xtensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh3
-rwxr-xr-xtensorflow/tools/ci_build/update_version.py110
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade.py4
-rw-r--r--tensorflow/tools/dist_test/README.md2
-rwxr-xr-xtensorflow/tools/docker/parameterized_docker_build.sh31
-rw-r--r--tensorflow/tools/docs/pretty_docs.py2
-rw-r--r--tensorflow/tools/graph_transforms/file_utils.h6
-rw-r--r--tensorflow/tools/pip_package/BUILD24
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh2
-rw-r--r--tensorflow/tools/pip_package/check_load_py_test.py3
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py32
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions_lib.h6
-rw-r--r--tensorflow/tools/test/BUILD9
-rw-r--r--tensorflow/tools/test/check_futures_test.py2
-rw-r--r--tensorflow/workspace.bzl22
-rw-r--r--third_party/aws.BUILD2
-rw-r--r--third_party/eigen3/BUILD2
-rw-r--r--third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h6
-rw-r--r--third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h6
-rw-r--r--third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h6
-rw-r--r--third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h6
-rw-r--r--third_party/fft2d/fft.h6
-rw-r--r--third_party/gpus/cuda_configure.bzl104
-rw-r--r--third_party/jpeg/jpeg.BUILD53
-rw-r--r--third_party/swig.BUILD2
-rw-r--r--third_party/tensorrt/BUILD.tpl75
-rw-r--r--third_party/tensorrt/build_defs.bzl.tpl23
-rw-r--r--third_party/tensorrt/tensorrt_configure.bzl224
-rw-r--r--third_party/toolchains/clang6/BUILD1
-rw-r--r--third_party/toolchains/clang6/CROSSTOOL.tpl587
-rw-r--r--third_party/toolchains/clang6/README.md101
-rw-r--r--third_party/toolchains/clang6/clang.BUILD162
-rw-r--r--third_party/toolchains/clang6/repo.bzl30
1370 files changed, 40205 insertions, 15868 deletions
diff --git a/RELEASE.md b/RELEASE.md
index 39fc46ac63..fdf10407fd 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -160,7 +160,7 @@ answered questions, and were part of inspiring discussions.
# Release 1.4.1
## Bug Fixes and Other Changes
-* `LinearClassifier` fix for CloudML Engine.
+* `LinearClassifier` fix.
# Release 1.4.0
diff --git a/WORKSPACE b/WORKSPACE
index 7ae39374f1..1e38a9a8cd 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -41,12 +41,12 @@ load("//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace()
new_http_archive(
- name = "inception5h",
+ name = "inception_v1",
build_file = "models.BUILD",
- sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364",
+ sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105",
urls = [
- "http://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip",
- "http://download.tensorflow.org/models/inception5h.zip",
+ "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
+ "http://download.tensorflow.org/models/inception_v1.zip",
],
)
diff --git a/configure.py b/configure.py
index 580bbc0ebe..b621b1bc1b 100644
--- a/configure.py
+++ b/configure.py
@@ -959,6 +959,119 @@ def set_tf_cudnn_version(environ_cp):
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
+def set_tf_tensorrt_install_path(environ_cp):
+ """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
+
+ Adapted from code contributed by Sami Kama (https://github.com/samikama).
+
+ Args:
+ environ_cp: copy of the os.environ.
+
+ Raises:
+ ValueError: if this method was called under non-Linux platform.
+ UserInputError: if user has provided invalid input multiple times.
+ """
+ if not is_linux():
+ raise ValueError('Currently TensorRT is only supported on Linux platform.')
+
+ # Ask user whether to add TensorRT support.
+ if str(int(get_var(
+ environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1':
+ return
+
+ for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
+ ask_tensorrt_path = (r'Please specify the location where TensorRT is '
+ 'installed. [Default is %s]:') % (
+ _DEFAULT_TENSORRT_PATH_LINUX)
+ trt_install_path = get_from_env_or_user_or_default(
+ environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
+ _DEFAULT_TENSORRT_PATH_LINUX)
+
+ # Result returned from "read" will be used unexpanded. That make "~"
+ # unusable. Going through one more level of expansion to handle that.
+ trt_install_path = os.path.realpath(
+ os.path.expanduser(trt_install_path))
+
+ def find_libs(search_path):
+ """Search for libnvinfer.so in "search_path"."""
+ fl = set()
+ if os.path.exists(search_path) and os.path.isdir(search_path):
+ fl.update([os.path.realpath(os.path.join(search_path, x))
+ for x in os.listdir(search_path) if 'libnvinfer.so' in x])
+ return fl
+
+ possible_files = find_libs(trt_install_path)
+ possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
+ possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
+
+ def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
+ """Check the compatibility between tensorrt and cudnn/cudart libraries."""
+ ldd_bin = which('ldd') or '/usr/bin/ldd'
+ ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
+ cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
+ cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
+ cudnn = None
+ cudart = None
+ for line in ldd_out:
+ if 'libcudnn.so' in line:
+ cudnn = cudnn_pattern.search(line)
+ elif 'libcudart.so' in line:
+ cudart = cuda_pattern.search(line)
+ if cudnn and len(cudnn.group(1)):
+ cudnn = convert_version_to_int(cudnn.group(1))
+ if cudart and len(cudart.group(1)):
+ cudart = convert_version_to_int(cudart.group(1))
+ return (cudnn == cudnn_ver) and (cudart == cuda_ver)
+
+ cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
+ cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
+ nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
+ highest_ver = [0, None, None]
+
+ for lib_file in possible_files:
+ if is_compatible(lib_file, cuda_ver, cudnn_ver):
+ ver_str = nvinfer_pattern.search(lib_file).group(1)
+ ver = convert_version_to_int(ver_str) if len(ver_str) else 0
+ if ver > highest_ver[0]:
+ highest_ver = [ver, ver_str, lib_file]
+ if highest_ver[1] is not None:
+ trt_install_path = os.path.dirname(highest_ver[2])
+ tf_tensorrt_version = highest_ver[1]
+ break
+
+ # Try another alternative from ldconfig.
+ ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
+ ldconfig_output = run_shell([ldconfig_bin, '-p'])
+ search_result = re.search(
+ '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output)
+ if search_result:
+ libnvinfer_path_from_ldconfig = search_result.group(2)
+ if os.path.exists(libnvinfer_path_from_ldconfig):
+ if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver):
+ trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
+ tf_tensorrt_version = search_result.group(1)
+ break
+
+ # Reset and Retry
+ print('Invalid path to TensorRT. None of the following files can be found:')
+ print(trt_install_path)
+ print(os.path.join(trt_install_path, 'lib'))
+ print(os.path.join(trt_install_path, 'lib64'))
+ if search_result:
+ print(libnvinfer_path_from_ldconfig)
+
+ else:
+ raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
+ 'times in a row. Assuming to be a scripting mistake.' %
+ _DEFAULT_PROMPT_ASK_ATTEMPTS)
+
+ # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
+ environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
+ write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
+ environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
+ write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
+
+
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@@ -1382,6 +1495,8 @@ def main():
'TF_CUDA_CONFIG_REPO' not in environ_cp):
set_tf_cuda_version(environ_cp)
set_tf_cudnn_version(environ_cp)
+ if is_linux():
+ set_tf_tensorrt_install_path(environ_cp)
set_tf_cuda_compute_capabilities(environ_cp)
set_tf_cuda_clang(environ_cp)
@@ -1440,6 +1555,7 @@ def main():
'more details.')
config_info_line('mkl', 'Build with MKL support.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
+ config_info_line('tensorrt', 'Build with TensorRT support.')
if __name__ == '__main__':
main()
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index da37564697..b26c525525 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -370,12 +370,21 @@ config_setting(
visibility = ["//visibility:public"],
)
+# TODO(laigd): consider removing this option and make TensorRT enabled
+# automatically when CUDA is enabled.
+config_setting(
+ name = "with_tensorrt_support",
+ values = {"define": "with_tensorrt_support=true"},
+ visibility = ["//visibility:public"],
+)
+
package_group(
name = "internal",
packages = [
"//learning/meta_rank/...",
"//tensorflow/...",
"//tensorflow_fold/llgtm/...",
+ "//third_party/py/tensor2tensor/...",
],
)
@@ -441,9 +450,6 @@ filegroup(
"//tensorflow/contrib/all_reduce:all_files",
"//tensorflow/contrib/android:all_files",
"//tensorflow/contrib/batching:all_files",
- "//tensorflow/contrib/batching/kernels:all_files",
- "//tensorflow/contrib/batching/test_util:all_files",
- "//tensorflow/contrib/batching/util:all_files",
"//tensorflow/contrib/bayesflow:all_files",
"//tensorflow/contrib/boosted_trees:all_files",
"//tensorflow/contrib/boosted_trees/estimator_batch:all_files",
@@ -537,7 +543,7 @@ filegroup(
"//tensorflow/contrib/periodic_resample:all_files",
"//tensorflow/contrib/predictor:all_files",
"//tensorflow/contrib/py2tf:all_files",
- "//tensorflow/contrib/py2tf/convert:all_files",
+ "//tensorflow/contrib/py2tf/converters:all_files",
"//tensorflow/contrib/py2tf/pyct:all_files",
"//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
"//tensorflow/contrib/quantize:all_files",
@@ -568,6 +574,7 @@ filegroup(
"//tensorflow/contrib/tensor_forest/proto:all_files",
"//tensorflow/contrib/tensorboard:all_files",
"//tensorflow/contrib/tensorboard/db:all_files",
+ "//tensorflow/contrib/tensorrt:all_files",
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/text:all_files",
"//tensorflow/contrib/tfprof:all_files",
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index f258bcd956..c46cb32aa4 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -26,6 +26,18 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
+filegroup(
+ name = "srcs",
+ srcs = glob(
+ [
+ "*.cc",
+ "*.h",
+ ],
+ exclude = ["*test*"],
+ ),
+ visibility = ["//visibility:public"],
+)
+
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api.h"],
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 6fc75a98f1..3c7f041b39 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -927,6 +927,7 @@ int TF_DeviceListCount(const TF_DeviceList* list) {
status->status = InvalidArgument("index out of bounds"); \
return err_val; \
} \
+ status->status = Status::OK(); \
return list->response[index].accessor; \
}
@@ -1469,7 +1470,13 @@ int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
}
int TF_OperationNumControlInputs(TF_Operation* oper) {
- return oper->node.in_edges().size() - oper->node.num_inputs();
+ int count = 0;
+ for (const auto* edge : oper->node.in_edges()) {
+ if (edge->IsControlEdge() && !edge->src()->IsSource()) {
+ ++count;
+ }
+ }
+ return count;
}
int TF_OperationGetControlInputs(TF_Operation* oper,
@@ -1477,7 +1484,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper,
int max_control_inputs) {
int count = 0;
for (const auto* edge : oper->node.in_edges()) {
- if (edge->IsControlEdge()) {
+ if (edge->IsControlEdge() && !edge->src()->IsSource()) {
if (count < max_control_inputs) {
control_inputs[count] = ToOperation(edge->src());
}
@@ -1490,7 +1497,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper,
int TF_OperationNumControlOutputs(TF_Operation* oper) {
int count = 0;
for (const auto* edge : oper->node.out_edges()) {
- if (edge->IsControlEdge()) {
+ if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
++count;
}
}
@@ -1502,7 +1509,7 @@ int TF_OperationGetControlOutputs(TF_Operation* oper,
int max_control_outputs) {
int count = 0;
for (const auto* edge : oper->node.out_edges()) {
- if (edge->IsControlEdge()) {
+ if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
if (count < max_control_outputs) {
control_outputs[count] = ToOperation(edge->dst());
}
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index df697e16d3..01954eb235 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -575,7 +575,7 @@ TEST(CAPI, ImportGraphDef) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
- // Create a graph with two nodes: x and 3
+ // Create a simple graph.
Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
@@ -586,7 +586,7 @@ TEST(CAPI, ImportGraphDef) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
- // Export to a GraphDef
+ // Export to a GraphDef.
TF_Buffer* graph_def = TF_NewBuffer();
TF_GraphToGraphDef(graph, graph_def, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
@@ -606,6 +606,31 @@ TEST(CAPI, ImportGraphDef) {
ASSERT_TRUE(feed != nullptr);
ASSERT_TRUE(neg != nullptr);
+ // Test basic structure of the imported graph.
+ EXPECT_EQ(0, TF_OperationNumInputs(scalar));
+ EXPECT_EQ(0, TF_OperationNumInputs(feed));
+ ASSERT_EQ(1, TF_OperationNumInputs(neg));
+ TF_Output neg_input = TF_OperationInput({neg, 0});
+ EXPECT_EQ(scalar, neg_input.oper);
+ EXPECT_EQ(0, neg_input.index);
+
+ // Test that we can't see control edges involving the source and sink nodes.
+ TF_Operation* control_ops[100];
+ EXPECT_EQ(0, TF_OperationNumControlInputs(scalar));
+ EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100));
+ EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar));
+ EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100));
+
+ EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
+ EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100));
+ EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
+ EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100));
+
+ EXPECT_EQ(0, TF_OperationNumControlInputs(neg));
+ EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100));
+ EXPECT_EQ(0, TF_OperationNumControlOutputs(neg));
+ EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100));
+
// Import it again, with an input mapping, return outputs, and a return
// operation, into the same graph.
TF_DeleteImportGraphDefOptions(opts);
@@ -629,7 +654,7 @@ TEST(CAPI, ImportGraphDef) {
ASSERT_TRUE(neg2 != nullptr);
// Check input mapping
- TF_Output neg_input = TF_OperationInput({neg, 0});
+ neg_input = TF_OperationInput({neg, 0});
EXPECT_EQ(scalar, neg_input.oper);
EXPECT_EQ(0, neg_input.index);
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index 3429009a71..6acc2fec00 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_
+#ifndef TENSORFLOW_C_C_TEST_UTIL_H_
+#define TENSORFLOW_C_C_TEST_UTIL_H_
#include "tensorflow/c/c_api.h"
@@ -136,4 +136,4 @@ class CSession {
std::vector<TF_Operation*> targets_;
};
-#endif // THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_
+#endif // TENSORFLOW_C_C_TEST_UTIL_H_
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 04a415b909..a76c8f5ec0 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -118,6 +118,23 @@ void TFE_ContextClearCaches(TFE_Context* ctx) {
tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
}
+void TFE_ContextSetThreadLocalDevicePlacementPolicy(
+ TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
+ tensorflow::mutex_lock ml(ctx->policy_map_mu);
+ ctx->thread_local_policies[std::this_thread::get_id()] = policy;
+}
+
+extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
+ TFE_Context* ctx) {
+ tensorflow::mutex_lock ml(ctx->policy_map_mu);
+ auto policy_map_it =
+ ctx->thread_local_policies.find(std::this_thread::get_id());
+ if (policy_map_it != ctx->thread_local_policies.end()) {
+ return policy_map_it->second;
+ }
+ return ctx->policy;
+}
+
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
@@ -435,10 +452,17 @@ tensorflow::Status ValidateInputTypeAndPlacement(
const tensorflow::Device* actual_device =
op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
if (expected_device != actual_device) {
- switch (ctx->policy) {
- case TFE_DEVICE_PLACEMENT_EXPLICIT:
+ switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
+ case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
// TODO(xpan): See if we could bubble python related error up
// to python level.
+ if (op->inputs[i].dtype() == tensorflow::DT_INT32) {
+ // Note: enabling silent copies of int32 tensors to match behavior
+ // of graph mode.
+ break;
+ }
+ TF_FALLTHROUGH_INTENDED;
+ case TFE_DEVICE_PLACEMENT_EXPLICIT:
return tensorflow::errors::InvalidArgument(
"Tensors on conflicting devices:"
" cannot compute ",
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 9b0fd037da..387de07894 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -61,14 +61,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
// Controls how to act when we try to run an operation on a given device but
// some input tensors are not on that device.
typedef enum TFE_ContextDevicePlacementPolicy {
- // The default: running operations with input tensors on the wrong device will
- // fail.
+ // Running operations with input tensors on the wrong device will fail.
TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
TFE_DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the
// operation will be blocked till the copy completes.
TFE_DEVICE_PLACEMENT_SILENT = 2,
+ // Default placement policy which silently copies int32 tensors but not other
+ // dtypes.
+ TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
} TFE_ContextDevicePlacementPolicy;
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
@@ -93,6 +95,18 @@ TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
// ops.
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
+// Sets a thread-local device placement policy. After this call, other calls to
+// TFE_Execute in the same thread will use the device policy specified here
+// instead of the device policy used to construct the context. This has no
+// effect on the device policy used by other program threads.
+TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy(
+ TFE_Context*, TFE_ContextDevicePlacementPolicy);
+
+// Returns the device placement policy to be used by this context in the current
+// thread.
+TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy
+TFE_ContextGetDevicePlacementPolicy(TFE_Context*);
+
// A handle to a tensor on a device.
//
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 55a04d48ba..a6f76c732f 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <cstddef>
#include <memory>
#include <string>
+#include <thread>
#include <vector>
#include "tensorflow/c/c_api.h"
@@ -37,7 +38,8 @@ limitations under the License.
struct TFE_ContextOptions {
TF_SessionOptions session_options;
- TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
+ TFE_ContextDevicePlacementPolicy policy{
+ TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
};
struct TFE_Context {
@@ -45,6 +47,12 @@ struct TFE_Context {
TFE_ContextDevicePlacementPolicy policy;
+ // Note: we cannot use C++11 thread_local here as there is no concept of a
+ // thread-local-object-local variable in C++11.
+ tensorflow::mutex policy_map_mu;
+ std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy>
+ thread_local_policies GUARDED_BY(policy_map_mu);
+
// TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
TF_Session* session;
tensorflow::Rendezvous* rendezvous;
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 423a7e1ff7..18e7a64435 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -321,6 +321,55 @@ TEST(CAPI, TensorHandleSilentCopy) {
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
+TEST(CAPI, TensorHandleSilentCopyLocal) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts,
+ TFE_DEVICE_PLACEMENT_EXPLICIT);
+ TFE_Context* ctx = TFE_NewContext(opts, status.get());
+ TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
+ TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_DeleteContextOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+ TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ const int num_devices = TF_DeviceListCount(devices);
+
+ // Disable the test if no GPU is present.
+ if (num_devices > 1) {
+ const int device_to_use = 1;
+ const string name(TF_DeviceListName(devices, device_to_use, status.get()));
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_TensorHandle* hgpu =
+ TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
+ TFE_OpSetDevice(matmul, name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_DeleteOp(matmul);
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteTensorHandle(hgpu);
+ }
+
+ TF_DeleteDeviceList(devices);
+ TF_DeleteTensor(t);
+ TFE_DeleteTensorHandle(hcpu);
+ TFE_DeleteContext(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+}
+
TEST(CAPI, Execute) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index b51ef2b531..aa9d9e06b2 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
-#define THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
+#ifndef TENSORFLOW_C_PYTHON_API_H_
+#define TENSORFLOW_C_PYTHON_API_H_
#include "tensorflow/c/c_api.h"
@@ -39,4 +39,4 @@ void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_
+#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index ddcee3deee..c9ade5fb83 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -673,7 +673,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:tensorflow",
],
)
diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h
index 1b5f7dd923..c7256a7dc3 100644
--- a/tensorflow/cc/framework/cc_op_gen.h
+++ b/tensorflow/cc/framework/cc_op_gen.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
+#define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
@@ -28,4 +28,4 @@ void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
diff --git a/tensorflow/cc/framework/grad_op_registry.h b/tensorflow/cc/framework/grad_op_registry.h
index 190b96f685..0fc5abb20c 100644
--- a/tensorflow/cc/framework/grad_op_registry.h
+++ b/tensorflow/cc/framework/grad_op_registry.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
+#define TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
#include <unordered_map>
@@ -72,4 +72,4 @@ class GradOpRegistry {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
diff --git a/tensorflow/cc/framework/gradient_checker.h b/tensorflow/cc/framework/gradient_checker.h
index d055c60d09..1aa215a908 100644
--- a/tensorflow/cc/framework/gradient_checker.h
+++ b/tensorflow/cc/framework/gradient_checker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
+#define TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -60,4 +60,4 @@ Status ComputeGradientError(const Scope& scope, const Output& x,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_
diff --git a/tensorflow/cc/framework/gradients.h b/tensorflow/cc/framework/gradients.h
index 717f6f0636..0a377ad56d 100644
--- a/tensorflow/cc/framework/gradients.h
+++ b/tensorflow/cc/framework/gradients.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_
+#define TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -49,4 +49,4 @@ Output NoGradient();
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_
diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h
index 8d4154220c..a085e1d6e2 100644
--- a/tensorflow/cc/framework/ops.h
+++ b/tensorflow/cc/framework/ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_
+#define TENSORFLOW_CC_FRAMEWORK_OPS_H_
#include <type_traits>
@@ -296,4 +296,4 @@ class InputList {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_OPS_H_
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index 0225ac0472..30c32bd44b 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
+#define TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
#include <memory>
#include <string>
@@ -242,4 +242,4 @@ struct CompositeOpScopes {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 968c366550..8efcfed20d 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
+#define TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
#include "tensorflow/cc/framework/scope.h"
@@ -117,4 +117,4 @@ class Scope::Impl {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
diff --git a/tensorflow/cc/framework/testutil.h b/tensorflow/cc/framework/testutil.h
index a3e19870ec..7ad6fb4a67 100644
--- a/tensorflow/cc/framework/testutil.h
+++ b/tensorflow/cc/framework/testutil.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_
+#define TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -44,4 +44,4 @@ void GetTensor(const Scope& scope, const std::vector<Output>& assign_vars,
} // namespace test
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_
diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h
index 8f592accc9..cb4e579c85 100644
--- a/tensorflow/cc/framework/while_gradients.h
+++ b/tensorflow/cc/framework/while_gradients.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
-#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+#ifndef TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+#define TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -37,4 +37,4 @@ Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+#endif // TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
diff --git a/tensorflow/cc/gradients/grad_testutil.h b/tensorflow/cc/gradients/grad_testutil.h
index d31f412754..70c81f1a73 100644
--- a/tensorflow/cc/gradients/grad_testutil.h
+++ b/tensorflow/cc/gradients/grad_testutil.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_
+#ifndef TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_
+#define TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -32,4 +32,4 @@ Status CallGradFunction(const Scope& scope, const Operation& op,
} // namespace test
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_
+#endif // TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_
diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h
index d11fda475b..424a683665 100644
--- a/tensorflow/cc/ops/const_op.h
+++ b/tensorflow/cc/ops/const_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
+#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_
+#define TENSORFLOW_CC_OPS_CONST_OP_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -82,4 +82,4 @@ std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
} // namespace ops
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
+#endif // TENSORFLOW_CC_OPS_CONST_OP_H_
diff --git a/tensorflow/cc/ops/standard_ops.h b/tensorflow/cc/ops/standard_ops.h
index 0c021f0b3a..98f53010ec 100644
--- a/tensorflow/cc/ops/standard_ops.h
+++ b/tensorflow/cc/ops/standard_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
+#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_
+#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/candidate_sampling_ops.h"
@@ -37,4 +37,4 @@ limitations under the License.
#include "tensorflow/cc/ops/training_ops.h"
#include "tensorflow/cc/ops/user_ops.h"
-#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
+#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_
diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h
index a04476056a..727237b5c7 100644
--- a/tensorflow/cc/ops/while_loop.h
+++ b/tensorflow/cc/ops/while_loop.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_
-#define THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+#ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+#define TENSORFLOW_CC_OPS_WHILE_LOOP_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -71,4 +71,4 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
} // namespace ops
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+#endif // TENSORFLOW_CC_OPS_WHILE_LOOP_H_
diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h
index e1ce315d3c..6077c45c58 100644
--- a/tensorflow/cc/profiler/profiler.h
+++ b/tensorflow/cc/profiler/profiler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_
-#define THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_
+#ifndef TENSORFLOW_CC_PROFILER_PROFILER_H_
+#define TENSORFLOW_CC_PROFILER_PROFILER_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -94,4 +94,4 @@ class Profiler {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_
+#endif // TENSORFLOW_CC_PROFILER_PROFILER_H_
diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h
index c940df8a87..645a3f101d 100644
--- a/tensorflow/cc/saved_model/constants.h
+++ b/tensorflow/cc/saved_model/constants.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
-#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
+#ifndef TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
+#define TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
namespace tensorflow {
@@ -47,4 +47,4 @@ constexpr char kSavedModelVariablesFilename[] = "variables";
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
+#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h
index 3d634dd515..a8e098fa54 100644
--- a/tensorflow/cc/saved_model/loader.h
+++ b/tensorflow/cc/saved_model/loader.h
@@ -15,8 +15,8 @@ limitations under the License.
/// SavedModel loading functions and SavedModelBundle struct.
-#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
-#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
+#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
+#define TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
#include <string>
#include <unordered_set>
@@ -61,4 +61,4 @@ bool MaybeSavedModelDirectory(const string& export_dir);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
+#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_H_
diff --git a/tensorflow/cc/saved_model/signature_constants.h b/tensorflow/cc/saved_model/signature_constants.h
index b2d39bd55b..7d8c07f5cf 100644
--- a/tensorflow/cc/saved_model/signature_constants.h
+++ b/tensorflow/cc/saved_model/signature_constants.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_
-#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_
+#ifndef TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_
+#define TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_
namespace tensorflow {
@@ -66,4 +66,4 @@ static constexpr char kRegressOutputs[] = "outputs";
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_
+#endif // TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_
diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h
index b71cb263ca..68a090e0c4 100644
--- a/tensorflow/cc/saved_model/tag_constants.h
+++ b/tensorflow/cc/saved_model/tag_constants.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_
-#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_
+#ifndef TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_
+#define TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_
namespace tensorflow {
@@ -32,4 +32,4 @@ constexpr char kSavedModelTagTrain[] = "train";
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_
+#endif // TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_
diff --git a/tensorflow/cc/tools/freeze_saved_model.h b/tensorflow/cc/tools/freeze_saved_model.h
index bd5e0516c8..b10f29805a 100644
--- a/tensorflow/cc/tools/freeze_saved_model.h
+++ b/tensorflow/cc/tools/freeze_saved_model.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
-#define THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
+#ifndef TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
+#define TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
#include <unordered_set>
@@ -40,4 +40,4 @@ Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
+#endif // TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
index 57244a4f0a..52a81a5028 100644
--- a/tensorflow/cc/tools/freeze_saved_model_test.cc
+++ b/tensorflow/cc/tools/freeze_saved_model_test.cc
@@ -71,7 +71,7 @@ class FreezeTest : public ::testing::Test {
return Status::OK();
}
- // Adds `graph_def` to `saved_model_bundle` and intializes a session with
+ // Adds `graph_def` to `saved_model_bundle` and initializes a session with
// `init_node`.
Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def,
const string& init_node,
diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h
index 0e01b19cd9..7168b77525 100644
--- a/tensorflow/cc/training/coordinator.h
+++ b/tensorflow/cc/training/coordinator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
+#ifndef TENSORFLOW_CC_TRAINING_COORDINATOR_H_
+#define TENSORFLOW_CC_TRAINING_COORDINATOR_H_
#include <atomic>
#include <memory>
@@ -128,4 +128,4 @@ class Coordinator {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_
+#endif // TENSORFLOW_CC_TRAINING_COORDINATOR_H_
diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h
index 2d34500323..21189b4b04 100644
--- a/tensorflow/cc/training/queue_runner.h
+++ b/tensorflow/cc/training/queue_runner.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
-#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
+#ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
+#define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
#include <memory>
#include <string>
@@ -137,4 +137,4 @@ class QueueRunner : public RunnerInterface {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
+#endif // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index f7c6cd293a..314f5506b1 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -403,11 +403,6 @@ tf_xla_py_test(
disabled_backends = [
"gpu",
],
- tags = [
- "manual",
- "no_oss",
- "notap",
- ],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 65706b35d6..c95fb1c515 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -43,7 +43,7 @@ class BinaryOpsTest(XLATestCase):
output = op(pa, pb)
result = session.run(output, {pa: a, pb: b})
if equality_test is None:
- equality_test = self.assertAllClose
+ equality_test = self.assertAllCloseAccordingToType
equality_test(result, expected, rtol=1e-3)
def _testSymmetricBinary(self, op, a, b, expected, equality_test=None):
@@ -54,14 +54,20 @@ class BinaryOpsTest(XLATestCase):
"""Tests closeness of two lists of floats."""
self.assertEqual(len(result), len(expected))
for i in range(len(result)):
- self.assertAllClose(result[i], expected[i], rtol)
+ self.assertAllCloseAccordingToType(result[i], expected[i], rtol)
def testFloatOps(self):
for dtype in self.float_types:
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ a = -1.01
+ b = 4.1
+ else:
+ a = -1.001
+ b = 4.01
self._testBinary(
lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
- np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype),
- np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype),
+ np.array([[[[-1, 2.00009999], [-3, b]]]], dtype=dtype),
+ np.array([[[[a, 2], [-3.00009, 4]]]], dtype=dtype),
expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
self._testBinary(
@@ -768,15 +774,15 @@ class BinaryOpsTest(XLATestCase):
def DISABLED_testSparseMatMul(self):
# Binary wrappers for sparse_matmul with different hints
def SparseMatmulWrapperTF(a, b):
- return tf.sparse_matmul(a, b, a_is_sparse=True)
+ return math_ops.sparse_matmul(a, b, a_is_sparse=True)
def SparseMatmulWrapperFT(a, b):
- return tf.sparse_matmul(a, b, b_is_sparse=True)
+ return math_ops.sparse_matmul(a, b, b_is_sparse=True)
def SparseMatmulWrapperTT(a, b):
- return tf.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True)
+ return math_ops.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True)
- self._testMatMul(tf.sparse_matmul)
+ self._testMatMul(math_ops.sparse_matmul)
self._testMatMul(SparseMatmulWrapperTF)
self._testMatMul(SparseMatmulWrapperFT)
self._testMatMul(SparseMatmulWrapperTT)
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index e84b790037..538fa8e8e5 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -65,7 +65,7 @@ class RGBToHSVTest(XLATestCase):
# Verify that processing batch elements together is the same as separate
self.assertAllClose(batch1, join1)
self.assertAllClose(batch2, join2)
- self.assertAllClose(batch2, inp)
+ self.assertAllCloseAccordingToType(batch2, inp, bfloat16_atol=0.03)
def testRGBToHSVRoundTrip(self):
data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
@@ -77,21 +77,25 @@ class RGBToHSVTest(XLATestCase):
hsv = image_ops.rgb_to_hsv(placeholder)
rgb = image_ops.hsv_to_rgb(hsv)
rgb_tf = rgb.eval(feed_dict={placeholder: rgb_np})
- self.assertAllClose(rgb_tf, rgb_np)
+ self.assertAllCloseAccordingToType(rgb_tf, rgb_np, bfloat16_atol=0.03)
def testRGBToHSVNumpy(self):
"""Tests the RGB to HSV conversion matches a reference implementation."""
for nptype in self.float_types:
rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype)
rgb_np = rgb_flat.reshape(4, 4, 4, 3)
- hsv_np = np.array([colorsys.rgb_to_hsv(r, g, b) for r, g, b in rgb_flat])
+ hsv_np = np.array([
+ colorsys.rgb_to_hsv(
+ r.astype(np.float64), g.astype(np.float64), b.astype(np.float64))
+ for r, g, b in rgb_flat
+ ])
hsv_np = hsv_np.reshape(4, 4, 4, 3)
with self.test_session():
placeholder = array_ops.placeholder(nptype)
with self.test_scope():
hsv_op = image_ops.rgb_to_hsv(placeholder)
hsv_tf = hsv_op.eval(feed_dict={placeholder: rgb_np})
- self.assertAllClose(hsv_tf, hsv_np)
+ self.assertAllCloseAccordingToType(hsv_tf, hsv_np)
class AdjustContrastTest(XLATestCase):
@@ -427,7 +431,8 @@ class ResizeBilinearTest(XLATestCase):
np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype),
align_corners=True)
out = sess.run(resized, {grads: grads_np[np.newaxis, :, :, np.newaxis]})
- self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out)
+ self.assertAllCloseAccordingToType(expected[np.newaxis, :, :, np.newaxis],
+ out)
def testAlignCorners1x2To3x2(self):
for dtype in self.float_types:
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 0a6fe04d3c..8e4b8a3833 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -67,8 +67,10 @@ class UnaryOpsTest(XLATestCase):
output = op(pinp)
result = session.run(output, {pinp: inp})
if equality_test is None:
- equality_test = self.assertAllCloseAccordingToType
- equality_test(result, expected, rtol=rtol, atol=atol)
+ self.assertAllCloseAccordingToType(
+ result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
+ else:
+ equality_test(result, expected, rtol=rtol, atol=atol)
def ListsAreClose(self, result, expected, rtol, atol):
"""Tests closeness of two lists of floats."""
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 8062f0c03c..02215b5112 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -144,7 +145,9 @@ Status GraphCompiler::Compile() {
} else {
device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
Status s = op_context.status();
- TF_RETURN_IF_ERROR(s);
+ if (!s.ok()) {
+ return AttachDef(s, n->def());
+ }
}
// Set up outputs. Also check if outputs from the previous computation is
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h
index 575086e118..ca57be3d47 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_util.h
+++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h
@@ -31,4 +31,4 @@ Status TensorShapeToConstant(const TensorShape& input_shape,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_
+#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index 79da701fd2..672e19bd93 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -29,7 +29,7 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
arg_names_(static_data.arg_names),
result_names_(static_data.result_names),
program_shape_(static_data.program_shape),
- hlo_profile_printer_(static_data.hlo_profile_printer) {
+ hlo_profile_printer_data_(static_data.hlo_profile_printer_data) {
// Allocate arg and temp buffers.
if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index e0ae3ed9a8..48a8c083ca 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -26,7 +26,7 @@ limitations under the License.
// never use this functionality.
namespace xla {
class ProgramShape;
-class HloProfilePrinter;
+class HloProfilePrinterData;
}
namespace tensorflow {
@@ -77,12 +77,14 @@ class XlaCompiledCpuFunction {
// [Optional] Arg and result shapes.
const xla::ProgramShape* program_shape = nullptr;
- // [Optional] Profile printer. Null if profiling is disabled.
- const xla::HloProfilePrinter* hlo_profile_printer = nullptr;
+ // [Optional] Profile printer data. Null if profiling is disabled.
+ const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr;
// [Optional] The number of profile counters expected in the profile counter
// buffer by the generated code and hlo_profile_printer. 0 if profiling is
- // disabled.
+ // disabled. This information is already present in
+ // hlo_profile_printer_data but xla::HloProfilePrinterData is forward
+ // declared so we don't have access to that information here.
int64 profile_counters_size = 0;
};
@@ -205,10 +207,12 @@ class XlaCompiledCpuFunction {
// program shape isn't available.
const xla::ProgramShape* ProgramShape() const { return program_shape_; }
- bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; }
- const xla::HloProfilePrinter& hlo_profile_printer() const {
+ bool hlo_profiling_enabled() const {
+ return hlo_profile_printer_data_ != nullptr;
+ }
+ const xla::HloProfilePrinterData& hlo_profile_printer_data() const {
assert(hlo_profiling_enabled());
- return *hlo_profile_printer_;
+ return *hlo_profile_printer_data_;
}
private:
@@ -234,7 +238,7 @@ class XlaCompiledCpuFunction {
const char** arg_names_ = nullptr;
const char** result_names_ = nullptr;
const xla::ProgramShape* program_shape_ = nullptr;
- const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr;
+ const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 93aae8485d..7ebe4b75bc 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -227,6 +227,42 @@ TEST_F(XlaCompilerTest, Simple) {
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}
+TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
+ // Builds a graph that adds reshapes a tensor, but with the shape not
+ // statically known.
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
+ auto c = ops::Reshape(scope.WithOpName("C"), a, b);
+ auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(2);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[1].kind = XlaCompiler::Argument::kParameter;
+ args[1].type = DT_INT32;
+ args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ Status status =
+ compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
+ std::move(graph), args, &result);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(
+ StringPiece(status.error_message()).contains("depends on a parameter"))
+ << status.error_message();
+ EXPECT_TRUE(
+ StringPiece(status.error_message()).contains("[[Node: C = Reshape"))
+ << status.error_message();
+}
+
// Tests handling of compile-time constant outputs.
TEST_F(XlaCompilerTest, ConstantOutputs) {
// Builds a graph with one compile-time constant output and one data-dependent
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 584417bc72..1fe6e69ff2 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -182,10 +182,10 @@ XlaJitCompiledCpuFunction::Compile(
jit->static_data_.program_shape = jit->program_shape_.get();
if (cpu_executable->hlo_profiling_enabled()) {
- jit->static_data_.hlo_profile_printer =
- &cpu_executable->hlo_profile_printer();
+ jit->static_data_.hlo_profile_printer_data =
+ &cpu_executable->hlo_profile_printer_data();
jit->static_data_.profile_counters_size =
- cpu_executable->hlo_profile_printer().profile_counters_size();
+ cpu_executable->hlo_profile_printer_data().profile_counters_size();
}
return std::move(jit_unique_ptr);
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index a9dcb662b3..ee0aed672e 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -118,13 +118,36 @@ Status XlaOpKernelContext::ConstantInputReshaped(
std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
+ xla::StatusOr<bool> is_constant = builder()->IsConstant(handle);
+ if (!is_constant.ok()) {
+ Status status = is_constant.status();
+ errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
+ context_->op_kernel().type_string(),
+ " operator as a compile-time constant.");
+ return status;
+ }
+
+ if (!is_constant.ValueOrDie()) {
+ return errors::InvalidArgument(
+ "Input ", index, " to ", context_->op_kernel().type_string(),
+ " operator must be a compile-time constant.\n"
+ "\n"
+ "XLA compilation requires that operator arguments that represent "
+ "shapes or dimensions be evaluated to concrete values at compile time. "
+ "This error means that a shape or dimension argument could not be "
+ "evaluated at compile time, usually because the value of the argument "
+ "depends on a parameter to the computation, on a variable, or on a "
+ "stateful operation such as a random number generator.");
+ }
+
// Ask the XLA compiler to evaluate the data handle to a literal.
xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
builder()->ComputeConstant(handle, &layout);
if (!computed.ok()) {
- return errors::InvalidArgument(
- "Error evaluating ", context_->op_kernel().name(), " input ", index,
- ": ", computed.status().error_message());
+ return errors::Internal("Error evaluating ", context_->op_kernel().name(),
+ " input ", index,
+ "as a compile-time constant.\nError: ",
+ computed.status().error_message());
}
*constant_literal = std::move(*computed.ValueOrDie());
@@ -389,10 +412,20 @@ XlaCompiler* XlaOpKernelContext::compiler() const {
return XlaContext::Get(context_).compiler();
}
-void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
-void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
+void XlaOpKernelContext::CtxFailure(const Status& s) {
+ context_->CtxFailure(s);
+}
+void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
context_->CtxFailureWithWarning(s);
}
+void XlaOpKernelContext::CtxFailure(const char* file, int line,
+ const Status& s) {
+ context_->CtxFailure(file, line, s);
+}
+void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
+ const Status& s) {
+ context_->CtxFailureWithWarning(file, line, s);
+}
const xla::Computation* XlaOpKernelContext::GetOrCreateMax(
const DataType type) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index f1ae81a5aa..6d3b6db228 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -173,8 +173,10 @@ class XlaOpKernelContext {
const xla::ComputationDataHandle& handle);
// Helper routines for the OP_REQUIRES macros
- void CtxFailure(Status s);
- void CtxFailureWithWarning(Status s);
+ void CtxFailure(const Status& s);
+ void CtxFailureWithWarning(const Status& s);
+ void CtxFailure(const char* file, int line, const Status& s);
+ void CtxFailureWithWarning(const char* file, int line, const Status& s);
// If this kernel invocation is within a function execution,
// call_frame() returns the call frame for the function call.
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 438f1443f1..c22fd37129 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -182,6 +182,7 @@ cc_library(
deps = [
":status",
":status_macros",
+ ":statusor",
":types",
":xla_data_proto",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index d6b4ebfc39..952109dde2 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -98,6 +98,7 @@ cc_library(
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:source_map_util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@llvm//:support",
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 18ab24d9e6..ea4cdb7667 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -67,7 +67,7 @@ class ComputationBuilder {
// OpMetadata is often applied to a series of XLA HLO instructions. As a
// result, OpMetadata is set on the Computation Builder. All subsequent
// instructions generated via this Computation Builder will have the same
- // OpMetadata attached until a call to ClearOpMetdata.
+ // OpMetadata attached until a call to ClearOpMetadata.
void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
// Clears the HloMetadata state.
@@ -715,7 +715,7 @@ class ComputationBuilder {
ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle);
// Returns true if 'operand' is a compile-time constant. A compile-time
- // constant does not depend on parameters with higher index then
+ // constant does not depend on parameters with index greater than or equal to
// `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`.
// Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a
// compile-time constant without evaluating the computation.
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 523169fdd2..fbeedfcecd 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -21,10 +21,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace se = ::perftools::gputools;
+using xla::source_map_util::InvalidParameterArgument;
+
namespace xla {
ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
@@ -79,9 +82,10 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
for (int i = 0; i < arguments.size(); ++i) {
if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
arguments[i]->on_host_shape())) {
- return InvalidArgument(
- "argument does not match shape or layout of computation parameter "
- "%d: expected %s, got %s",
+ return InvalidParameterArgument(
+ executable_.get(), i,
+ "Argument does not match shape or layout of computation parameter "
+ "%d: want %s, got %s",
i,
ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
.c_str(),
diff --git a/tensorflow/compiler/xla/execution_options_util.h b/tensorflow/compiler/xla/execution_options_util.h
index 562da78e83..a8ca27ec8d 100644
--- a/tensorflow/compiler/xla/execution_options_util.h
+++ b/tensorflow/compiler/xla/execution_options_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
#include "tensorflow/compiler/xla/xla.pb.h"
@@ -26,4 +26,4 @@ ExecutionOptions CreateDefaultExecutionOptions();
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
+#endif // TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_
diff --git a/tensorflow/compiler/xla/iterator_util.h b/tensorflow/compiler/xla/iterator_util.h
index a39999705e..a8bb8c7a7e 100644
--- a/tensorflow/compiler/xla/iterator_util.h
+++ b/tensorflow/compiler/xla/iterator_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
#include <iterator>
#include <utility>
@@ -95,4 +95,4 @@ UnwrappingIterator<NestedIter> MakeUnwrappingIterator(NestedIter iter) {
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index e88bffd0ba..fe3a4d2f6d 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -224,6 +224,11 @@ void AllocateFlags() {
"xla_dump_hlo_proto_to", flag_values->mutable_xla_dump_hlo_proto_to(),
"Dump compilation artifacts as proto binary into this directory."),
tensorflow::Flag(
+ "xla_dump_prepass_hlo_proto_to",
+ flag_values->mutable_xla_dump_prepass_hlo_proto_to(),
+ "Dump compilation artifacts, before hlo passes are executed, as "
+ "proto binary into this directory."),
+ tensorflow::Flag(
"xla_test_all_output_layouts",
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
flag_values->xla_test_all_output_layouts(),
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h
index d0ef8e66ab..b53157f59c 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
#include <vector>
@@ -35,4 +35,4 @@ xla::DebugOptions GetDebugOptionsFromFlags();
} // namespace legacy_flags
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
index 0c238e6a5d..e9cf435d83 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -148,4 +148,4 @@ inline bool parse_xla_reduce_precision_option(
} // namespace legacy_flags
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 7f0201e74a..89279b659c 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -830,6 +830,16 @@ std::unique_ptr<Literal> Literal::Slice(
result_literal->Set<float>(indices, value);
});
return result_literal;
+ case C64:
+ result_literal->EachCell<complex64>(
+ [&](tensorflow::gtl::ArraySlice<int64> indices, complex64 /*value*/) {
+ for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
+ new_indices[i] = indices[i] + start_indices[i];
+ }
+ complex64 value = Get<complex64>(new_indices);
+ result_literal->Set<complex64>(indices, value);
+ });
+ return result_literal;
case S32:
result_literal->EachCell<int32>(
[&](tensorflow::gtl::ArraySlice<int64> indices, int32 /*value*/) {
diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h
index 50659c1240..0ad0b91330 100644
--- a/tensorflow/compiler/xla/map_util.h
+++ b/tensorflow/compiler/xla/map_util.h
@@ -16,6 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
+#include <functional>
+#include <sstream>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -44,6 +49,22 @@ typename Collection::value_type::second_type& FindOrDie(
return it->second;
}
+// Like FindOrDie but returns an error instead of dying if `key` is not in
+// `container`.
+template <class Collection>
+StatusOr<
+ std::reference_wrapper<const typename Collection::value_type::second_type>>
+MaybeFind(const Collection& collection,
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ std::ostringstream os;
+ os << key;
+ return NotFound("key not found: %s", os.str().c_str());
+ }
+ return {it->second};
+}
+
// Inserts the key-value pair into the collection. Dies if key was already
// present.
template <class Collection>
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index 2bce56b7bd..143c9a2366 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -20,79 +20,6 @@ limitations under the License.
namespace xla {
namespace primitive_util {
-template <>
-PrimitiveType NativeToPrimitiveType<bool>() {
- return PRED;
-}
-
-// Unsigned integer
-template <>
-PrimitiveType NativeToPrimitiveType<uint8>() {
- return U8;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<uint16>() {
- return U16;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<uint32>() {
- return U32;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<uint64>() {
- return U64;
-}
-
-// Signed integer
-template <>
-PrimitiveType NativeToPrimitiveType<int8>() {
- return S8;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<int16>() {
- return S16;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<int32>() {
- return S32;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<int64>() {
- return S64;
-}
-
-// Floating point
-template <>
-PrimitiveType NativeToPrimitiveType<float>() {
- return F32;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<double>() {
- return F64;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<bfloat16>() {
- return BF16;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<half>() {
- return F16;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<complex64>() {
- return C64;
-}
-
bool IsFloatingPointType(PrimitiveType type) {
return type == F16 || type == F32 || type == F64 || type == BF16;
}
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index cb4583d198..b26a10ade6 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -47,49 +47,81 @@ PrimitiveType NativeToPrimitiveType() {
}
// Declarations of specializations for each native type which correspond to a
-// XLA primitive type.
+// XLA primitive type. As an optimization, these are declared inline in the
+// header.
template <>
-PrimitiveType NativeToPrimitiveType<bool>();
+inline PrimitiveType NativeToPrimitiveType<bool>() {
+ return PRED;
+}
// Unsigned integer
template <>
-PrimitiveType NativeToPrimitiveType<uint8>();
+inline PrimitiveType NativeToPrimitiveType<uint8>() {
+ return U8;
+}
template <>
-PrimitiveType NativeToPrimitiveType<uint16>();
+inline PrimitiveType NativeToPrimitiveType<uint16>() {
+ return U16;
+}
template <>
-PrimitiveType NativeToPrimitiveType<uint32>();
+inline PrimitiveType NativeToPrimitiveType<uint32>() {
+ return U32;
+}
template <>
-PrimitiveType NativeToPrimitiveType<uint64>();
+inline PrimitiveType NativeToPrimitiveType<uint64>() {
+ return U64;
+}
// Signed integer
template <>
-PrimitiveType NativeToPrimitiveType<int8>();
+inline PrimitiveType NativeToPrimitiveType<int8>() {
+ return S8;
+}
template <>
-PrimitiveType NativeToPrimitiveType<int16>();
+inline PrimitiveType NativeToPrimitiveType<int16>() {
+ return S16;
+}
template <>
-PrimitiveType NativeToPrimitiveType<int32>();
+inline PrimitiveType NativeToPrimitiveType<int32>() {
+ return S32;
+}
template <>
-PrimitiveType NativeToPrimitiveType<int64>();
+inline PrimitiveType NativeToPrimitiveType<int64>() {
+ return S64;
+}
// Floating point
template <>
-PrimitiveType NativeToPrimitiveType<float>();
+inline PrimitiveType NativeToPrimitiveType<float>() {
+ return F32;
+}
+
template <>
-PrimitiveType NativeToPrimitiveType<double>();
+inline PrimitiveType NativeToPrimitiveType<double>() {
+ return F64;
+}
+
template <>
-PrimitiveType NativeToPrimitiveType<half>();
+inline PrimitiveType NativeToPrimitiveType<half>() {
+ return F16;
+}
+
template <>
-PrimitiveType NativeToPrimitiveType<bfloat16>();
+inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
+ return BF16;
+}
// Complex
template <>
-PrimitiveType NativeToPrimitiveType<complex64>();
+inline PrimitiveType NativeToPrimitiveType<complex64>() {
+ return C64;
+}
bool IsFloatingPointType(PrimitiveType type);
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 5455adafcd..66ace613a0 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -36,15 +36,22 @@ from tensorflow.compiler.xla.python import pywrap_xla as c_api
# pylint: disable=invalid-name
-OpMetadata = collections.namedtuple(
- 'OpMetadata',
- [
- 'op_type',
- 'op_name',
- 'source_file',
- 'source_line',
- ],
-)
+_OP_METADATA_FIELDS = [
+ 'op_type',
+ 'op_name',
+ 'source_file',
+ 'source_line',
+]
+OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS)
+
+
+def OpMetadataToProto(pyobj):
+ proto = xla_data_pb2.OpMetadata()
+ for field in _OP_METADATA_FIELDS:
+ attr = getattr(pyobj, field)
+ if attr is not None:
+ setattr(proto, field, attr)
+ return proto
def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
@@ -555,22 +562,12 @@ class ComputationBuilder(object):
A ComputationDataHandle representing the added pad op.
"""
if not isinstance(padding_config, xla_data_pb2.PaddingConfig):
- padding_config = self._GetPaddingConfigFromTriples(padding_config)
+ padding_config = GetPaddingConfigFromTriples(padding_config)
return _wrap_data_handle(
self._client.Pad(_unwrap_data_handle(operand),
_unwrap_data_handle(padding_value),
padding_config))
- def _GetPaddingConfigFromTriples(self, triples):
- """Create PaddingConfig proto from list of triples of integers."""
- padding_config = xla_data_pb2.PaddingConfig()
- for lo, hi, interior in triples:
- dimension = padding_config.dimensions.add()
- dimension.edge_padding_low = lo
- dimension.edge_padding_high = hi
- dimension.interior_padding = interior
- return padding_config
-
def Reshape(self, operand, dimensions, new_sizes):
"""Reshape op."""
return _wrap_data_handle(
@@ -997,3 +994,14 @@ def get_replica_count():
yet or not.
"""
return c_api.GetReplicaCount()
+
+
+def GetPaddingConfigFromTriples(triples):
+ """Create PaddingConfig proto from list of triples of integers."""
+ padding_config = xla_data_pb2.PaddingConfig()
+ for lo, hi, interior in triples:
+ dimension = padding_config.dimensions.add()
+ dimension.edge_padding_low = lo
+ dimension.edge_padding_high = hi
+ dimension.interior_padding = interior
+ return padding_config
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 71341c6f1e..469acc330c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -29,6 +29,11 @@ xla_proto_library(
deps = ["//tensorflow/compiler/xla:xla_data_proto"],
)
+xla_proto_library(
+ name = "hlo_profile_printer_data",
+ srcs = ["hlo_profile_printer_data.proto"],
+)
+
# Filegroup used to collect source files for dependency checking.
filegroup(
name = "c_srcs",
@@ -452,8 +457,10 @@ cc_library(
":hlo_evaluator",
":hlo_execution_profile",
":hlo_module_config",
+ ":hlo_proto_util",
":platform_util",
":session_proto",
+ ":source_map_util",
":transfer_manager",
":user_computation",
":versioned_computation_handle",
@@ -905,6 +912,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1084,6 +1092,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
@@ -1940,6 +1949,16 @@ cc_library(
],
)
+tf_cc_test(
+ name = "hlo_element_type_converter_test",
+ srcs = ["hlo_element_type_converter_test.cc"],
+ deps = [
+ ":hlo_element_type_converter",
+ ":hlo_matchers",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ ],
+)
+
cc_library(
name = "device_memory_allocator",
srcs = ["device_memory_allocator.cc"],
@@ -2252,6 +2271,7 @@ cc_library(
srcs = ["hlo_profile_printer.cc"],
hdrs = ["hlo_profile_printer.h"],
deps = [
+ ":hlo_profile_printer_data",
":human_readable_profile_builder",
"//tensorflow/compiler/xla:types",
],
@@ -2329,6 +2349,18 @@ tf_cc_test(
],
)
+cc_library(
+ name = "source_map_util",
+ srcs = ["source_map_util.cc"],
+ hdrs = ["source_map_util.h"],
+ deps = [
+ ":executable",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 90a3f0b674..ba82e822b2 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1741,6 +1741,63 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
}
}
+ // If the pad puts a single non-identity value in each window that we're
+ // reducing, then this is a broadcast.
+ HloInstruction* pad_operand = operand->mutable_operand(0);
+ auto is_effective_broadcast = [&] {
+ if (window_util::HasStride(window)) {
+ VLOG(10) << "Window has stride.";
+ return false;
+ }
+ if (!window_util::HasSymmetricPadding(pad_config)) {
+ VLOG(10) << "Window has uneven padding.";
+ return false;
+ }
+ for (int64 i = 0; i < pad_config.dimensions_size(); ++i) {
+ const auto& pad_dimension = pad_config.dimensions(i);
+ if ((pad_dimension.edge_padding_low() != 0 ||
+ pad_dimension.edge_padding_high() != 0) &&
+ pad_operand->shape().dimensions(i) != 1) {
+ VLOG(10) << "Found non-trivial dimension being padded: " << i;
+ return false;
+ }
+ }
+ VLOG(10) << "Found to be padding trivial dimensions only.";
+
+ for (int64 i = 0; i < window.dimensions_size(); ++i) {
+ const auto& pad_dimension = pad_config.dimensions(i);
+ const WindowDimension& window_dimension = window.dimensions(i);
+ bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
+ pad_dimension.edge_padding_high() != 0);
+ if (dimension_has_padding &&
+ window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
+ VLOG(10) << "Found window did not cover single unpadded element in "
+ "dimension: "
+ << i;
+ return false;
+ }
+ if (pad_operand->shape().dimensions(i) != 1 &&
+ window_dimension.size() != 1) {
+ VLOG(10) << "Found window covers more than one element in non-trivial "
+ "dimension: "
+ << i;
+ return false;
+ }
+ }
+ VLOG(10) << "Found window covers a single unpadded element.";
+ return true;
+ };
+ if (is_effective_broadcast()) {
+ VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast.";
+ auto fadd = [this](std::unique_ptr<HloInstruction> x) {
+ return computation_->AddInstruction(std::move(x));
+ };
+ return ReplaceWithNewInstruction(
+ reduce_window, HloInstruction::CreateBroadcastSequence(
+ /*output_shape=*/reduce_window->shape(),
+ /*operand=*/pad_operand, fadd));
+ }
+
// Carry out the folding of the pad into reduce_window.
VLOG(10) << "Folding pad into reduce-window.";
Window new_window = window;
@@ -1758,7 +1815,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return ReplaceWithNewInstruction(
reduce_window, HloInstruction::CreateReduceWindow(
/*shape=*/reduce_window->shape(),
- /*operand=*/operand->mutable_operand(0),
+ /*operand=*/pad_operand,
/*init_value=*/reduce_window->mutable_operand(1),
/*window=*/new_window,
/*reduce_computation=*/function));
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index e7c4dfb0a1..e43ea50af4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -2495,6 +2496,144 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
op::DynamicSlice(op::Parameter(), op::Parameter()));
}
+struct PadReduceWindowEffectiveBroadcastCase {
+ std::vector<int64> input_spatials;
+ std::vector<int64> symmetric_pad_spatials;
+ std::vector<int64> reduce_window_spatials;
+ // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
+ //
+ // This doesn't test any different functionality but is useful for making sure
+ // kBroadcast nodes are well formed.
+ bool prepend_a;
+ bool should_become_broadcast;
+
+ string ToTestCaseName() const {
+ return tensorflow::strings::StrCat(
+ tensorflow::str_util::Join(input_spatials, ","), ";",
+ tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";",
+ tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a,
+ ";", should_become_broadcast);
+ }
+};
+
+void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
+ *os << c.ToTestCaseName();
+}
+
+class PadReduceWindowEffectiveBroadcastTest
+ : public AlgebraicSimplifierTest,
+ public ::testing::WithParamInterface<
+ PadReduceWindowEffectiveBroadcastCase> {};
+
+TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
+ const auto& param = GetParam();
+
+ // a and b are parallel bounds we can either turn into a B F S0 S1 or
+ // `B S0 S1 F` kind of pattern.
+ auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice<int64> spatials,
+ int64 a, int64 b) {
+ std::vector<int64> result;
+ if (param.prepend_a) {
+ result.push_back(a);
+ }
+ for (int64 s : spatials) {
+ result.push_back(s);
+ }
+ if (!param.prepend_a) {
+ result.push_back(a);
+ }
+ result.push_back(b);
+ return result;
+ };
+
+ HloComputation::Builder builder(TestName());
+ const Shape input_shape = ShapeUtil::MakeShape(
+ F32, decorate_spatials(param.input_spatials, 128, 2048));
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, input_shape, "input"));
+
+ PaddingConfig padding = window_util::MakeSymmetricPadding(
+ decorate_spatials(param.symmetric_pad_spatials, 0, 0));
+ HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(
+ F32, decorate_spatials(param.reduce_window_spatials, 128, 2048)),
+ input,
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
+ padding));
+
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module->AddEmbeddedComputation(builder.Build());
+ }
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ const Shape output_shape,
+ ShapeInference::InferPadShape(input_shape, ShapeUtil::MakeShape(F32, {}),
+ padding));
+ Window window = window_util::MakeWindow(
+ decorate_spatials(param.reduce_window_spatials, 1, 1));
+ auto zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ builder.AddInstruction(HloInstruction::CreateReduceWindow(
+ output_shape, pad, zero, window, add_computation));
+
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get()));
+ ASSERT_TRUE(run_successful);
+
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));
+
+ if (param.should_become_broadcast) {
+ EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_));
+ } else {
+ EXPECT_THAT(computation->root_instruction(),
+ op::ReduceWindow(::testing::_, zero));
+ }
+}
+
+const std::vector<PadReduceWindowEffectiveBroadcastCase>&
+PadReduceWindowEffectiveBroadcastCases() {
+ static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
+ {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
+ /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
+ /*should_become_broadcast=*/true}, //
+ {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
+ /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
+ /*should_become_broadcast=*/true}, //
+ {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
+ /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
+ /*should_become_broadcast=*/false}, //
+ {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
+ /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true,
+ /*should_become_broadcast=*/true}, //
+ {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
+ /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true,
+ /*should_become_broadcast=*/false}, //
+ {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
+ /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
+ /*should_become_broadcast=*/false}, //
+ };
+ return *cases;
+}
+
+INSTANTIATE_TEST_CASE_P(
+ PadReduceWindowEffectiveBroadcastInstantiation,
+ PadReduceWindowEffectiveBroadcastTest,
+ ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
+
class DotStrengthReductionTest
: public AlgebraicSimplifierTest,
public ::testing::WithParamInterface<
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 33fe11b81d..d5594dc07c 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -846,14 +846,13 @@ Status BufferAssigner::AssignBuffersForComputation(
continue;
}
- if (is_thread_local || instruction->opcode() == HloOpcode::kCustomCall) {
- // Custom call operations never have reusable buffers. Also we do not
- // reuse thread-local buffers for now, because they are dynamically
- // allocated and their lifetimes are hard to compute.
+ if (is_thread_local) {
+ // We do not reuse thread-local buffers for now, because they are
+ // dynamically allocated and their lifetimes are hard to compute.
BufferAllocation* allocation = assignment->NewAllocation(
*buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
VLOG(3) << "New allocation #" << allocation->index()
- << " for thread-local/CustomCall: " << *buffer;
+ << " for thread-local: " << *buffer;
continue;
}
@@ -1359,6 +1358,43 @@ void BufferAssigner::BuildColocatedBufferSets(
index, points_to_analysis, &colocated_set);
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
});
+
+ // Add true_operand and conditional.true_computation.parameter(0) as a
+ // colocated buffer set. Note that this has to be done for each subshape
+ // in the true_operand of the conditional.
+ ShapeUtil::ForEachSubshape(
+ conditional_hlo->operand(1)->shape(),
+ [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ std::vector<const LogicalBuffer*> true_set;
+ // Add conditional.true_operand.
+ AddBufferToColocatedSet(conditional_hlo->operand(1), index,
+ points_to_analysis, &true_set);
+ // Add conditional.true_computation.parameter_instruction(0).
+ AddBufferToColocatedSet(
+ conditional_hlo->true_computation()->parameter_instruction(0),
+ index, points_to_analysis, &true_set);
+ AddSetToColocatedBufferSets(true_set, colocated_buffer_sets);
+ });
+
+ // Add false_operand and conditional.false_computation.parameter(0) as a
+ // colocated buffer set. Note that this has to be done for each subshape
+ // in the false_operand of the conditional.
+ ShapeUtil::ForEachSubshape(
+ conditional_hlo->operand(2)->shape(),
+ [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ std::vector<const LogicalBuffer*> false_set;
+ // Add conditional.false_operand.
+ AddBufferToColocatedSet(conditional_hlo->operand(2), index,
+ points_to_analysis, &false_set);
+ // Add conditional.false_computation.parameter_instruction(0).
+ AddBufferToColocatedSet(
+ conditional_hlo->false_computation()->parameter_instruction(
+ 0),
+ index, points_to_analysis, &false_set);
+ AddSetToColocatedBufferSets(false_set, colocated_buffer_sets);
+ });
}
}
}
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 9e96898d9b..dab73596e1 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -101,12 +101,13 @@ CompileOnlyService::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, instance.argument_layouts,
- &execution_options));
+ &execution_options, *user_computation));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(
versioned_handle, *module_config,
/*include_unreachable_instructions=*/true));
+ TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
hlo_modules.push_back(std::move(hlo_module));
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index f0507982b3..33af77e1a8 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -485,7 +485,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer;
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
if (module->config().hlo_profiling_enabled()) {
hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
@@ -505,8 +505,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
HloCostAnalysis cost_analysis(shape_size_bytes);
TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis));
- hlo_profile_printer =
- CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis);
+ hlo_profile_printer_data =
+ CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis);
computation_to_profile_idx =
hlo_profile_index_map->computation_to_profile_idx();
}
@@ -619,7 +619,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
cpu_executable.reset(new ParallelCpuExecutable(
std::move(jit), std::move(assignment), std::move(module),
std::move(function_names), std::move(aligned_constants),
- std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));
+ std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable)
@@ -698,7 +698,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new CpuExecutable(
std::move(jit), std::move(assignment), std::move(module), function_name,
- std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));
+ std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index f335bd1bbc..802d0a6fb4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -55,9 +55,9 @@ CpuExecutable::CpuExecutable(
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name,
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
- : Executable(std::move(hlo_module), std::move(hlo_profile_printer),
+ : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
jit_(std::move(jit)),
assignment_(std::move(assignment)) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 50443a5995..267b89a10b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -51,7 +51,7 @@ class CpuExecutable : public Executable {
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name,
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~CpuExecutable() override {}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 2271af7b24..2924b63659 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -39,4 +39,4 @@ class CpuHloSupportChecker : public HloPassInterface {
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h
index 2994642356..664125ecc9 100644
--- a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h
+++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
// This file is depended on by kernels that have to build for mobile devices.
// For this reason, we avoid relying on TensorFlow and instead only use the
@@ -71,4 +71,4 @@ class RegisterCustomCallTarget {
} // namespace cpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index ebd96c4c42..99c5e16db7 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -33,8 +33,14 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
switch (op->opcode()) {
case HloOpcode::kTanh: {
PrimitiveType element_type = op->shape().element_type();
+ bool cast_result_to_fp16 = false;
string function_name;
switch (element_type) {
+ case F16:
+ cast_result_to_fp16 = true;
+ operand_value = ir_builder_->CreateFPCast(operand_value,
+ ir_builder_->getFloatTy());
+ TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "tanhf";
break;
@@ -44,7 +50,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
default:
return Unimplemented("tanh");
}
- // Create function declaration for 'tanhf'.
+ // Create a function declaration.
llvm::Function* function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
llvm_ir::AsStringRef(function_name), operand_value->getType(),
@@ -52,8 +58,12 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
- // Create instruction to call 'tanhf'.
- return ir_builder_->CreateCall(function, operand_value);
+ // Create an instruction to call the function.
+ llvm::Value* result = ir_builder_->CreateCall(function, operand_value);
+ if (cast_result_to_fp16) {
+ result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy());
+ }
+ return result;
}
default:
return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
@@ -63,7 +73,13 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const {
string function_name;
+ bool cast_result_to_fp16 = false;
switch (prim_type) {
+ case F16:
+ cast_result_to_fp16 = true;
+ lhs = ir_builder_->CreateFPCast(lhs, ir_builder_->getFloatTy());
+ rhs = ir_builder_->CreateFPCast(rhs, ir_builder_->getFloatTy());
+ TF_FALLTHROUGH_INTENDED;
case F32:
function_name = "atan2f";
break;
@@ -73,7 +89,7 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
default:
return Unimplemented("atan2");
}
- // Create function declaration for 'atan2'.
+ // Create a function declaration.
llvm::Function* function =
llvm::cast<llvm::Function>(module_->getOrInsertFunction(
llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(),
@@ -81,8 +97,12 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
- // Create instruction to call 'atan2'.
- return ir_builder_->CreateCall(function, {lhs, rhs});
+ // Create an instruction to call the function.
+ llvm::Value* result = ir_builder_->CreateCall(function, {lhs, rhs});
+ if (cast_result_to_fp16) {
+ result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy());
+ }
+ return result;
}
llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
index 9c00d476b1..8008a56df4 100644
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
+++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
#include <memory>
@@ -62,4 +62,4 @@ class ExternalConstantPool {
} // namespace cpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index cfdf9f4ebc..71e8133189 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -62,6 +62,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
@@ -72,6 +73,7 @@ namespace {
using llvm_ir::AsStringRef;
using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
+namespace gtl = tensorflow::gtl;
} // namespace
namespace cpu {
@@ -491,7 +493,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
}
Status IrEmitter::HandleMap(HloInstruction* map) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(map->operands());
+ gtl::ArraySlice<HloInstruction*> operands(map->operands());
HloComputation* function = map->to_apply();
// The called computation should have been emitted previously.
llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function);
@@ -1225,205 +1227,6 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex(
return index_with_free_var;
}
-Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
- // The output of BatchNormTraining is a tuple of three element:
- // - An N-dimensional array containing normalized values.
- // - A 1 dimensional array containing the mean value for each feature.
- // - A 1 dimensional array containing the variance value for each feature.
- HloInstruction* operand = batch_norm_training->operands()[0];
- HloInstruction* scale = batch_norm_training->operands()[1];
- HloInstruction* offset = batch_norm_training->operands()[2];
- float epsilon = batch_norm_training->epsilon();
- int64 feature_index = batch_norm_training->feature_index();
- TF_RET_CHECK(ShapeUtil::IsTuple(batch_norm_training->shape()) &&
- ShapeUtil::TupleElementCount(batch_norm_training->shape()) == 3);
-
- const Shape& output_shape =
- ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 0);
- const Shape& feature_shape =
- ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 1);
-
- // Reduce vector of the non-feature dimensions.
- std::vector<int64> dimensions_to_reduce;
-
- for (int64 i = 0; i < operand->shape().dimensions_size(); ++i) {
- if (i != feature_index) {
- dimensions_to_reduce.push_back(i);
- }
- }
-
- // Get the second and third allocations in the output tuple, which should be
- // used to store the result of mean and variance value calculation.
- TF_ASSIGN_OR_RETURN(
- const BufferAllocation::Slice slice_mean,
- assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{1}));
- TF_ASSIGN_OR_RETURN(
- const BufferAllocation::Slice slice_var,
- assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{2}));
- const int feature_count = output_shape.dimensions(feature_index);
- const int size_in_elements = ShapeUtil::ElementsIn(output_shape);
- TF_RET_CHECK(ShapeUtil::ElementsIn(operand->shape()) == size_in_elements);
- const int elements_per_feature = size_in_elements / feature_count;
-
- llvm::Value* mean = EmitTempBufferPointer(slice_mean, feature_shape);
- llvm_ir::IrArray mean_array(mean, feature_shape);
-
- llvm::Value* var = EmitTempBufferPointer(slice_var, feature_shape);
- llvm_ir::IrArray var_array(var, feature_shape);
-
- // This loop calculates mean and variance for each feature.
- //
- // In theory this could be swapped by multi-output fusion. We will evaluate
- // this when it's ready.
- //
- // For variance calculation, we use a simplified formula so we can fuse the
- // computation into the same loop to calculate mean: Var=E(X^2) - E(X)^2.
- TF_RETURN_IF_ERROR(
- llvm_ir::LoopEmitter(
- [&](const llvm_ir::IrArray::Index& index) {
- PrimitiveType element_type = operand->shape().element_type();
- // Used to calculate E(X).
- llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(element_type, module_),
- "sum_address", &ir_builder_,
- MinimumAlignmentForPrimitiveType(element_type));
-
- // Used to calculate E(X^2).
- llvm::Value* sum_square_address =
- llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(element_type, module_),
- "sum_square_address", &ir_builder_,
- MinimumAlignmentForPrimitiveType(element_type));
-
- ir_builder_.CreateStore(
- llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0),
- sum_address);
-
- ir_builder_.CreateStore(
- llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0),
- sum_square_address);
-
- llvm_ir::ForLoopNest loops(IrName(batch_norm_training, "inner"),
- &ir_builder_);
-
- const llvm_ir::IrArray::Index reduced_dims_index =
- loops.AddLoopsForShapeOnDimensions(
- operand->shape(), dimensions_to_reduce, "reduction_dim");
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(),
- &ir_builder_);
-
- llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
- llvm_ir::IrArray::Index input_index =
- FillReducedDimensionIndex(reduced_dims_index, index);
- llvm::Value* new_value =
- operand_array.EmitReadArrayElement(input_index, &ir_builder_);
-
- llvm::Value* new_value_square =
- ir_builder_.CreateFMul(new_value, new_value);
-
- llvm::Value* current_sum = ir_builder_.CreateLoad(sum_address);
- llvm::Value* current_sum_square =
- ir_builder_.CreateLoad(sum_square_address);
- // Update sum.
- ir_builder_.CreateStore(
- ir_builder_.CreateFAdd(current_sum, new_value), sum_address);
-
- // Update sum square.
- ir_builder_.CreateStore(
- ir_builder_.CreateFAdd(current_sum_square, new_value_square),
- sum_square_address);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(),
- &ir_builder_);
-
- llvm::Value* sum = ir_builder_.CreateLoad(sum_address);
- llvm::Value* elements_per_feature_value = llvm::ConstantFP::get(
- ir_builder_.getFloatTy(), elements_per_feature);
- llvm::Value* mean =
- ir_builder_.CreateFDiv(sum, elements_per_feature_value);
- llvm::Value* mean_square = ir_builder_.CreateFMul(mean, mean);
- llvm::Value* sum_square =
- ir_builder_.CreateLoad(sum_square_address);
-
- // Var=E(X^2) - E(X)^2.
- llvm::Value* var = ir_builder_.CreateFSub(
- ir_builder_.CreateFDiv(sum_square, elements_per_feature_value),
- mean_square);
-
- var_array.EmitWriteArrayElement(index, var, &ir_builder_);
- return mean;
- },
- mean_array, &ir_builder_)
- .EmitLoop(IrName(batch_norm_training, "mean_var")));
-
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training));
- TF_ASSIGN_OR_RETURN(
- const BufferAllocation::Slice slice,
- assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0}));
-
- llvm::Value* normalized = EmitTempBufferPointer(slice, output_shape);
-
- llvm_ir::IrArray target_array(normalized, output_shape);
-
- AddAliasingInformationToIrArray(*batch_norm_training, &target_array);
-
- TF_RETURN_IF_ERROR(
- llvm_ir::LoopEmitter(
- [this, mean_array, var_array, epsilon, operand, dimensions_to_reduce,
- feature_index, offset, scale](const llvm_ir::IrArray::Index& index) {
- // The following logic normalizes the input value, scales and shifts
- // it:
- //
- // normalized = (input - mean) / sqrt(variance + epsilon)
- // result = normalized * scale + offset
-
- // Current index in the feature dimension.
- llvm_ir::IrArray::Index feature_index_value(1,
- index[feature_index]);
-
- llvm::Value* mean = mean_array.EmitReadArrayElement(
- feature_index_value, &ir_builder_);
- llvm::Value* var = var_array.EmitReadArrayElement(
- feature_index_value, &ir_builder_);
-
- llvm_ir::IrArray operand_array(GetIrArrayFor(operand));
- llvm::Value* input =
- operand_array.EmitReadArrayElement(index, &ir_builder_);
-
- llvm::Value* variance_with_epsilon = ir_builder_.CreateFAdd(
- var, llvm::ConstantFP::get(ir_builder_.getFloatTy(), epsilon));
- llvm::Function* func_llvm_sqrt = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::sqrt, {ir_builder_.getFloatTy()});
- llvm::Value* variance_sqrt =
- ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon});
- llvm::Value* normalized = ir_builder_.CreateFDiv(
- ir_builder_.CreateFSub(input, mean), variance_sqrt);
- llvm_ir::IrArray offset_array(GetIrArrayFor(offset));
- llvm::Value* offset = offset_array.EmitReadArrayElement(
- feature_index_value, &ir_builder_);
- llvm_ir::IrArray scale_array(GetIrArrayFor(scale));
- llvm::Value* scale = scale_array.EmitReadArrayElement(
- feature_index_value, &ir_builder_);
- llvm::Value* result = ir_builder_.CreateFAdd(
- ir_builder_.CreateFMul(normalized, scale), offset);
-
- return result;
- },
- target_array, &ir_builder_)
- .EmitLoop(IrName(batch_norm_training, "normalize")));
-
- llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training),
- {normalized, mean, var}, &ir_builder_, module_);
- return Status::OK();
-}
-
-Status IrEmitter::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
- // TODO(b/62843645) Implement BatchNormGrad on CPU backend.
- return Unimplemented(
- "BatchNormGrad is not implemented on CPU. See b/62843645.");
-}
-
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
auto param_number = parameter->parameter_number();
@@ -1469,6 +1272,52 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
return Status::OK();
}
+// Returns true if the relative order of the unreduced dimensions stays the same
+// through the reduce operation.
+static bool ReductionPreservesLayout(const HloInstruction& reduce) {
+ DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce);
+
+ // Maps dimensions that were not reduced from their dimension numbers in the
+ // source shape to their dimensions numbers in the destination shape.
+ //
+ // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
+ // [0->0, 3->1].
+ gtl::FlatMap<int64, int64> unreduced_dim_map;
+
+ gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(),
+ reduce.dimensions().end());
+
+ const Shape& operand_shape = reduce.operand(0)->shape();
+ const Shape& result_shape = reduce.shape();
+
+ int64 delta = 0;
+ for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
+ if (reduced_dims.count(i)) {
+ delta++;
+ } else {
+ InsertOrDie(&unreduced_dim_map, i, i - delta);
+ }
+ }
+
+ // Iterate dimensions minor to major and check that the corresponding
+ // dimensions in the source and target shapes are equivalent.
+ int64 result_dim_idx = 0;
+ for (int64 operand_dim_idx = 0;
+ operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
+ int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
+ if (!reduced_dims.count(operand_dim)) {
+ if (FindOrDie(unreduced_dim_map, operand_dim) !=
+ result_shape.layout().minor_to_major(result_dim_idx++)) {
+ return false;
+ }
+ }
+ }
+
+ CHECK_EQ(result_dim_idx, result_shape.dimensions_size());
+
+ return true;
+}
+
IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
HloComputation* function, string* failure_reason) const {
CHECK_EQ(function->num_parameters(), 2);
@@ -1632,7 +1481,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction(
const ReductionGenerator& reduction_generator,
const llvm_ir::IrArray::Index& output_index,
const ShardedVectorType& accumulator_type, HloInstruction* init_value,
- HloInstruction* arg, tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloInstruction* arg, gtl::ArraySlice<int64> dimensions,
unsigned element_alignment) {
ShardedVector accumulator;
accumulator.reserve(accumulator_type.size());
@@ -1736,8 +1585,12 @@ void IrEmitter::EmitShardedVectorStore(
StatusOr<bool> IrEmitter::EmitVectorizedReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
- tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function,
+ gtl::ArraySlice<int64> dimensions, HloComputation* function,
string* failure_reason) {
+ if (!ReductionPreservesLayout(*reduce)) {
+ return false;
+ }
+
ReductionGenerator reduction_generator =
MatchReductionGenerator(function, failure_reason);
if (!reduction_generator) {
@@ -1881,7 +1734,7 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
- tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ gtl::ArraySlice<int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
if (!options::VectorizedReduceDisabled(hlo_module_config_)) {
string vectorization_failure_reason;
@@ -2001,7 +1854,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
//
// * Implement the memcpy within the innermost loop.
- tensorflow::gtl::FlatSet<int64> inner_dims;
+ gtl::FlatSet<int64> inner_dims;
for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
break;
@@ -2329,8 +2182,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
}
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(
- custom_call->operands());
+ gtl::ArraySlice<HloInstruction*> operands(custom_call->operands());
tensorflow::StringPiece custom_call_target(custom_call->custom_call_target());
llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =
@@ -2461,8 +2313,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
}
StatusOr<bool> IrEmitter::EmitFastConcatenate(
- HloInstruction* concatenate,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloInstruction* concatenate, gtl::ArraySlice<HloInstruction*> operands,
string* failure_reason) {
if (ShouldEmitParallelLoopFor(*concatenate)) {
*failure_reason =
@@ -2601,8 +2452,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
}
Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
- tensorflow::gtl::ArraySlice<HloInstruction*> operands(
- concatenate->operands());
+ gtl::ArraySlice<HloInstruction*> operands(concatenate->operands());
string failure_reason;
TF_ASSIGN_OR_RETURN(
bool successful,
@@ -2915,7 +2765,7 @@ llvm::Value* IrEmitter::EmitTempBufferPointer(
// for a single element_type value, and loads it after call.
llvm::Value* IrEmitter::EmitElementFunctionCall(
llvm::Function* function, const Shape& return_shape,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name) {
llvm::Value* return_value_buffer = EmitArrayFunctionCall(
function, return_shape, 1, parameter_addresses, name);
@@ -2935,8 +2785,7 @@ llvm::Value* IrEmitter::EmitElementFunctionCall(
// temps)
// return return_value_buffer -- address of the return value.
void IrEmitter::EmitArrayFunctionCallInto(
- llvm::Function* function,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses,
llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
ir_builder_.CreateCall(
function, GetArrayFunctionCallArguments(
@@ -2949,7 +2798,7 @@ void IrEmitter::EmitArrayFunctionCallInto(
llvm::Value* IrEmitter::EmitArrayFunctionCall(
llvm::Function* function, const Shape& return_shape, int64 element_count,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name) {
llvm::Value* elements =
llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count);
@@ -3059,8 +2908,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source,
Status IrEmitter::ElementTypesSameAndSupported(
const HloInstruction& instruction,
- tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
- tensorflow::gtl::ArraySlice<PrimitiveType> supported_types) {
+ gtl::ArraySlice<const HloInstruction*> operands,
+ gtl::ArraySlice<PrimitiveType> supported_types) {
for (auto operand : operands) {
TF_RET_CHECK(
ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 66f2aeeab3..5094402514 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -125,8 +125,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
- Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
- Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
Status HandleInfeed(HloInstruction* infeed) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
index 1fd2da4dce..557aa4a6bf 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@@ -131,4 +131,4 @@ Status EmitCallToParallelForkJoin(
} // namespace cpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h
index 2d29550fd5..f896384115 100644
--- a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h
+++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
#include <memory>
@@ -53,4 +53,4 @@ class Registrar {
} // namespace cpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
index d1b88b27f0..cd997f0789 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
@@ -61,9 +61,9 @@ ParallelCpuExecutable::ParallelCpuExecutable(
std::unique_ptr<const HloInstructionMap<string>> function_names,
std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
aligned_constants,
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
- : Executable(std::move(hlo_module), std::move(hlo_profile_printer),
+ : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
jit_(std::move(jit)),
assignment_(std::move(assignment)),
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
index 90ac94ef92..c393e9b8ea 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
@@ -55,7 +55,7 @@ class ParallelCpuExecutable : public Executable {
std::unordered_map<const HloInstruction*,
std::unique_ptr<unsigned char[]>>
aligned_constants,
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~ParallelCpuExecutable() override {}
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index 9335d2818e..ce92e36a94 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
@@ -70,4 +70,4 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
} // namespace cpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index 5801ec8d27..7140dabe51 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -99,4 +99,4 @@ class ParallelTaskAssigner : public HloPassInterface {
} // namespace cpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
index fcf1cc6207..1cf0ec6e3d 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
#include "tensorflow/core/platform/types.h"
@@ -30,4 +30,4 @@ extern void __xla_cpu_runtime_ParallelForkJoin(
} // extern "C"
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h
index cb7e0a81f0..1bd8dfb377 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
#include "tensorflow/core/platform/types.h"
@@ -42,4 +42,4 @@ void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m,
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h
index 7a2d00421c..33d02b70e6 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition.h
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
#include <vector>
@@ -102,4 +102,4 @@ class ShapePartitionIterator {
} // namespace cpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 5403bf48b7..de5e9b4119 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -47,7 +47,7 @@ namespace cpu {
namespace {
// A simple SymbolResolver that delegates to the host dynamic linker.
-class SimpleResolver : public llvm::JITSymbolResolver {
+class SimpleResolver : public llvm::LegacyJITSymbolResolver {
public:
explicit SimpleResolver(ExternalConstantPool* external_constant_pool)
: external_constant_pool_(external_constant_pool) {}
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index 5ff0ab34ea..1959b687f1 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -41,4 +41,4 @@ class DotDecomposer : public HloPassInterface {
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 21e7fbea29..90481c7a88 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -73,7 +73,7 @@ StatusOr<std::unique_ptr<ShapedBuffer>> Executable::ExecuteOnStreamWrapper(
std::unique_ptr<HloExecutionProfile> profile_ptr =
module_config().debug_options().xla_hlo_profile() &&
hlo_profiling_enabled()
- ? MakeUnique<HloExecutionProfile>(&hlo_profile_printer(),
+ ? MakeUnique<HloExecutionProfile>(&hlo_profile_printer_data(),
&hlo_profile_index_map())
: nullptr;
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 5ecfdffe21..0aee535ee7 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -44,13 +44,14 @@ namespace xla {
// interface that is used for launching compiled programs across platforms.
class Executable {
public:
- explicit Executable(std::unique_ptr<const HloModule> hlo_module,
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
- std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
+ explicit Executable(
+ std::unique_ptr<const HloModule> hlo_module,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
+ std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: hlo_module_(std::move(hlo_module)),
- hlo_profile_printer_(std::move(hlo_profile_printer)),
+ hlo_profile_printer_data_(std::move(hlo_profile_printer_data)),
hlo_profile_index_map_(std::move(hlo_profile_index_map)) {
- CHECK_EQ(hlo_profile_printer_.get() == nullptr,
+ CHECK_EQ(hlo_profile_printer_data_.get() == nullptr,
hlo_profile_index_map_.get() == nullptr);
}
virtual ~Executable() {}
@@ -116,9 +117,9 @@ class Executable {
"Equality test on this executable is not implemented.");
}
- const HloProfilePrinter& hlo_profile_printer() const {
+ const HloProfilePrinterData& hlo_profile_printer_data() const {
CHECK(hlo_profiling_enabled());
- return *hlo_profile_printer_;
+ return *hlo_profile_printer_data_;
}
const HloProfileIndexMap& hlo_profile_index_map() const {
@@ -129,7 +130,9 @@ class Executable {
// Returns whether this executable was compiled with HLO profilings support
// enabled. If not, the caller should not expect an hlo_execution_profile
// passed to ExecuteOnStream above to be populated during execution.
- bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; }
+ bool hlo_profiling_enabled() const {
+ return hlo_profile_printer_data_ != nullptr;
+ }
const HloModule& module() const { return *hlo_module_; }
@@ -179,7 +182,7 @@ class Executable {
// execution.
int64 execution_count_ = 0;
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer_;
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data_;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d7ca0f6846..3c3328b9cd 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -228,6 +228,7 @@ cc_library(
cc_library(
name = "gpu_executable",
srcs = [
+ "conditional_thunk.cc",
"convolution_thunk.cc",
"copy_thunk.cc",
"cudnn_batchnorm_thunk.cc",
@@ -243,6 +244,7 @@ cc_library(
"while_thunk.cc",
],
hdrs = [
+ "conditional_thunk.h",
"convolution_thunk.h",
"copy_thunk.h",
"cudnn_batchnorm_thunk.h",
@@ -475,6 +477,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_dce",
+ "//tensorflow/compiler/xla/service:hlo_element_type_converter",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto",
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
new file mode 100644
index 0000000000..790ca535b1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace gpu {
+
+ConditionalThunk::ConditionalThunk(
+ const BufferAllocation::Slice& predicate_buffer_index,
+ const BufferAllocation::Slice& true_operand_buffer_index,
+ const BufferAllocation::Slice& false_operand_buffer_index,
+ ThunkSequence true_thunk_sequence, ThunkSequence false_thunk_sequence,
+ const HloInstruction* hlo)
+ : Thunk(Kind::kConditional, hlo),
+ predicate_buffer_index_(predicate_buffer_index),
+ true_operand_buffer_index_(true_operand_buffer_index),
+ false_operand_buffer_index_(false_operand_buffer_index),
+ true_thunk_(std::move(true_thunk_sequence), hlo),
+ false_thunk_(std::move(false_thunk_sequence), hlo) {}
+
+Status ConditionalThunk::Initialize(const GpuExecutable& executable) {
+ TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable));
+ TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable));
+ return Status::OK();
+}
+
+Status ConditionalThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) {
+ // Copy the predicate value from device.
+ bool predicate;
+ perftools::gputools::DeviceMemoryBase predicate_address =
+ buffer_allocations.GetDeviceAddress(predicate_buffer_index_);
+ stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool));
+
+ Status block_status = stream->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ return InternalError("Failed to retrieve predicate value on stream %p: %s.",
+ stream, block_status.error_message().c_str());
+ }
+
+ // Execute the true or the false computation depending on the value of the
+ // predicate.
+ if (predicate) {
+ TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ } else {
+ TF_RETURN_IF_ERROR(
+ false_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ }
+
+ return Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
new file mode 100644
index 0000000000..7725c46a3b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// ConditionalThunk implements the conditional instruction on GPU by reading the
+// predicate of the conditional and executing the true or the false computation
+// depending on the value of the predicate.
+//
+// ConditionalThunk assumes that the buffers of the conditional result and the
+// result of the true and false computations share the same allocation. Also,
+// the buffers of the true operand of the conditional and that of the parameter
+// instruction of the true computation share the same allocation. Similarly, the
+// buffers of the false operand and that of the parameter instruction of the
+// false computation share the same allocation.
+class ConditionalThunk : public Thunk {
+ public:
+ ConditionalThunk(const BufferAllocation::Slice& predicate_buffer_index,
+ const BufferAllocation::Slice& true_operand_buffer_index,
+ const BufferAllocation::Slice& false_operand_buffer_index,
+ ThunkSequence true_thunk_sequence,
+ ThunkSequence false_thunk_sequence,
+ const HloInstruction* hlo);
+
+ ConditionalThunk(const ConditionalThunk&) = delete;
+ ConditionalThunk& operator=(const ConditionalThunk&) = delete;
+
+ Status Initialize(const GpuExecutable& executable) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ BufferAllocation::Slice predicate_buffer_index_;
+ BufferAllocation::Slice true_operand_buffer_index_;
+ BufferAllocation::Slice false_operand_buffer_index_;
+ SequentialThunk true_thunk_;
+ SequentialThunk false_thunk_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 4b511cb4bb..5af7a77ea8 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -72,9 +72,27 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
PrimitiveType output_type) const {
// The libdevice math functions differentiate between "double" and "float" by
- // appending an 'f' to the function's name.
+ // appending an 'f' to the function's name. libdevice doesn't have f16 math
+ // functions, so we convert the operands to f32 before calling the function
+ // and then convert the result back to f16.
string munged_callee = callee_name;
+ bool cast_result_to_fp16 = false;
+ std::vector<llvm::Value*> converted_operands(operands.begin(),
+ operands.end());
+ std::vector<PrimitiveType> converted_input_types(input_types.begin(),
+ input_types.end());
switch (output_type) {
+ case F16:
+ cast_result_to_fp16 = true;
+ for (int64 i = 0; i < operands.size(); ++i) {
+ if (input_types[i] == F16) {
+ converted_operands[i] = ir_builder_->CreateFPCast(
+ converted_operands[i], ir_builder_->getFloatTy());
+ converted_input_types[i] = F32;
+ }
+ }
+ output_type = F32;
+ TF_FALLTHROUGH_INTENDED;
case F32:
StrAppend(&munged_callee, "f");
break;
@@ -84,7 +102,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
return Unimplemented("Bad type for libdevice math call: %s",
PrimitiveType_Name(output_type).c_str());
}
- return EmitMathCall(munged_callee, operands, input_types, output_type);
+ llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
+ converted_input_types, output_type)
+ .ValueOrDie();
+ if (cast_result_to_fp16) {
+ result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy());
+ }
+ return result;
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
@@ -92,10 +116,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
PrimitiveType output_type) const {
- // llvm intrinsics differentiate between float/double functions via the ".f32"
- // and ".f64" suffixes.
+ // llvm intrinsics differentiate between half/float/double functions via
+ // the suffixes ".f16", ".f32" and ".f64".
string munged_callee = callee_name;
switch (output_type) {
+ case F16:
+ StrAppend(&munged_callee, ".f16");
+ break;
case F32:
StrAppend(&munged_callee, ".f32");
break;
@@ -233,12 +260,6 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
PrimitiveType input_type = op->operand(0)->shape().element_type();
PrimitiveType output_type = op->shape().element_type();
switch (op->opcode()) {
- case HloOpcode::kFloor:
- return EmitLibdeviceMathCall("__nv_floor", {operand_value}, {input_type},
- output_type);
- case HloOpcode::kCeil:
- return EmitLibdeviceMathCall("__nv_ceil", {operand_value}, {input_type},
- output_type);
case HloOpcode::kTanh:
return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type},
output_type);
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h
index 525a2af941..832494d17e 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
#include <vector>
@@ -49,4 +49,4 @@ class ForThunk : public Thunk {
} // namespace gpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index bd720f8584..4c523a66de 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -44,4 +44,4 @@ class FusionMerger : public HloPassInterface {
} // namespace gpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 89acac2c3f..0cca3ca092 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -58,6 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
@@ -137,6 +138,10 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module) {
// TODO(b/64094172): make Call work on GPU instead of inlining.
pipeline.AddPass<CallInliner>();
+ // Convert BF16 operations to F32 operations so that the GPU backend can
+ // support BF16 operations without directly implementing a BF16 lowering for
+ // most ops.
+ pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
pipeline.AddPass<DotDecomposer>();
{
auto& pass =
@@ -281,14 +286,16 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
return;
}
- // ptxas 9.0 before 9.0.276 miscompiles some address calculations with large
- // offsets (e.g. "load ptr + large_constant"), b/70245379.
- if (vmaj == 9 && vmin == 0 && vdot < 276) {
+ // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some
+ // address calculations with large offsets (e.g. "load ptr + large_constant"),
+ // b/70245379.
+ if ((vmaj == 9 && vmin == 0 && vdot < 276) ||
+ (vmaj == 9 && vmin == 1 && vdot < 121)) {
LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "."
<< vmin << "." << vdot
- << ", which is in range [9.0.0, 9.0.276). These versions are "
- "known to miscompile XLA code, leading to incorrect "
- "results or invalid-address errors.";
+ << ", which is in range [9.0.0, 9.0.276) + [9.1.0, 9.1.121). "
+ "These versions are known to miscompile XLA code, leading "
+ "to incorrect results or invalid-address errors.";
}
}
@@ -309,16 +316,24 @@ void WarnIfBadDriverJITVersion() {
}
se::cuda::DriverVersion version = version_or_status.ValueOrDie();
- // The driver JIT in 384 before 384.108 miscompiles some address
+ // The following versions of the driver JIT miscompile some address
// calculations with large offsets (e.g. "load ptr + large_constant"),
- // b/70245379.
- if (std::get<0>(version) == 384 && std::get<1>(version) < 108) {
+ // b/70245379:
+ //
+ // - 384.x before 384.108
+ // - 387.x before 387.40
+ // - 390.x before 390.10.
+ auto vmaj = std::get<0>(version);
+ auto vmin = std::get<1>(version);
+ if ((vmaj == 384 && vmin < 108) || //
+ (vmaj == 387 && vmin < 40) || //
+ (vmaj == 390 && vmin < 10)) {
LOG(WARNING)
<< "*** WARNING *** Invoking the PTX->SASS JIT from driver version "
<< se::cuda::DriverVersionToString(version)
- << ", which is in range [384.0.0, 384.108.0). These versions are "
- "known to miscompile XLA code, leading to incorrect results or "
- "invalid-address errors.";
+ << ", which is in range [384.0.0, 384.108.0) + [387.0.0, 387.40.0) + "
+ "[390.0.0, 390.10.0). These versions are known to miscompile XLA "
+ "code, leading to incorrect results or invalid-address errors.";
}
});
}
@@ -578,14 +593,14 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
XLA_VLOG_LINES(2, thunk_schedule->ToString());
std::unique_ptr<HloProfileIndexMap> profile_index_map;
- std::unique_ptr<HloProfilePrinter> profile_printer;
+ std::unique_ptr<HloProfilePrinterData> profile_printer;
if (module->config().hlo_profiling_enabled()) {
HloCostAnalysis cost_analysis(ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis));
profile_index_map = MakeUnique<HloProfileIndexMap>(*module);
profile_printer =
- CreateHloProfilePrinter(*profile_index_map, cost_analysis);
+ CreateHloProfilePrinterData(*profile_index_map, cost_analysis);
}
auto* gpu_executable = new GpuExecutable(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h
index 572c856282..eb1ca4c6c9 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_
#include "tensorflow/compiler/xla/types.h"
@@ -28,4 +28,4 @@ extern const int64 kCudaMallocAlignBytes;
} // namespace gpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index e67087d822..e3b493c663 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -36,7 +36,7 @@ namespace gpu {
StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
HloInstruction* hlo) {
- HloInstruction*& copy = inserted_copies_[hlo];
+ HloInstruction*& copy = hlo_to_copy_map_[hlo];
if (copy == nullptr) {
TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
}
@@ -86,27 +86,34 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
}
}
- // Init values of a while node cannot be constants. Insert copies for any
- // constants found at the operand of a while.
- tensorflow::gtl::FlatSet<HloInstruction*> copied_constants;
+ // Init values of while and conditional nodes cannot be constants. Insert
+ // copies for any constants found at the operands of these nodes.
+ tensorflow::gtl::FlatSet<HloInstruction*> inserted_copies;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kWhile) {
+ if (instruction->opcode() != HloOpcode::kWhile &&
+ instruction->opcode() != HloOpcode::kConditional) {
continue;
}
- for (auto& pair :
- dataflow->GetInstructionValueSet(instruction->operand(0))) {
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (value->defining_instruction()->opcode() ==
- HloOpcode::kConstant &&
- !ContainsKey(copied_constants, value->defining_instruction())) {
- HloInstruction* constant = value->defining_instruction();
- TF_ASSIGN_OR_RETURN(HloInstruction * copy,
- FindOrInsertCopy(constant));
- TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
- copied_constants.insert(constant);
- changed = true;
+ for (auto operand : instruction->operands()) {
+ // Skip the operands that have already been replaced with a copy in a
+ // previous iteration (which is possible when a constant is used as an
+ // operand in multiple places).
+ if (ContainsKey(inserted_copies, operand)) {
+ continue;
+ }
+ for (auto& pair : dataflow->GetInstructionValueSet(operand)) {
+ const HloValueSet& value_set = pair.second;
+ for (const HloValue* value : value_set.values()) {
+ if (value->defining_instruction()->IsConstant() &&
+ !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) {
+ HloInstruction* constant = value->defining_instruction();
+ TF_ASSIGN_OR_RETURN(HloInstruction * copy,
+ FindOrInsertCopy(constant));
+ TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
+ inserted_copies.insert(copy);
+ changed = true;
+ }
}
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 4d77f337e6..0c6f9b511f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -32,13 +32,13 @@ class GpuCopyInsertion : public HloPassInterface {
StatusOr<bool> Run(HloModule* module) override;
protected:
- // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
+ // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
// duplicate copies.
StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
// A map containing all copies inserted to materialize operands of library
// calls. The key is the copied instruction and the value is the copy.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
+ tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 51d164cdf4..f5d67b9ea9 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -116,9 +116,9 @@ GpuExecutable::GpuExecutable(
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const BufferAssignment> assignment,
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
- : Executable(std::move(hlo_module), std::move(hlo_profile_printer),
+ : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
ptx_(ptx),
cubin_(cubin),
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 00da64dfad..b19cfd43de 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -54,7 +54,7 @@ class GpuExecutable : public Executable {
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::unique_ptr<const HloModule> hlo_module,
std::unique_ptr<const BufferAssignment> assignment,
- std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
// This should be called after set_ir_module_string.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index d9550f81b5..d63e213d2b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -39,4 +39,4 @@ class GpuHloSupportChecker : public HloPassInterface {
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 095c3df3bf..23b72c3f71 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -758,37 +758,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
-Status IrEmitter::HandleConditional(HloInstruction* conditional) {
- auto pred = conditional->operand(0);
- auto true_arg = conditional->operand(1);
- auto false_arg = conditional->operand(2);
-
- llvm::Value* conditional_result = GetBasePointer(*conditional);
-
- llvm::LoadInst* pred_value = ir_builder_.CreateLoad(
- GetBasePointer(*pred),
- llvm_ir::AsStringRef(IrName(conditional, "load_predicate_value")));
- llvm::Value* pred_cond = ir_builder_.CreateICmpNE(
- pred_value,
- llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
- llvm_ir::AsStringRef(IrName(conditional, "boolean_predicate")));
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- pred_cond, IrName(conditional, "if_then_else"), &ir_builder_);
-
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *conditional->true_computation(), {GetBasePointer(*true_arg)},
- conditional_result));
-
- SetToFirstInsertPoint(if_data.false_block, &ir_builder_);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *conditional->false_computation(), {GetBasePointer(*false_arg)},
- conditional_result));
-
- SetToFirstInsertPoint(if_data.after_block, &ir_builder_);
- return Status::OK();
-}
-
llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest(
const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 39bafaa346..3aa178410f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -96,7 +96,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status HandleRng(HloInstruction* random) override;
- Status HandleConditional(HloInstruction* conditional) override;
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
@@ -367,6 +366,11 @@ class IrEmitterUnnested : public IrEmitter {
std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
const int64 loop_limit);
+ // Returns a ConditionalThunk that executes the thunk sequence for
+ // 'true_computation' or 'false_computation' depending on the value of the
+ // predicate in the given conditional instruction.
+ std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
+
Status Postprocess(HloInstruction* hlo) override;
// Returns the last generated thunk.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index be35351e87..fc8783e753 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
@@ -272,8 +273,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
}
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
- thunk_sequence_->push_back(BuildKernelThunk(conditional));
- return IrEmitter::HandleConditional(conditional);
+ thunk_sequence_->emplace_back(BuildConditionalThunk(conditional));
+ return Status::OK();
}
Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
@@ -2102,6 +2103,24 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
namespace {
+// Checks that the buffers corresponding to the given two HLOs share the same
+// allocation.
+Status CheckHloBuffersShareAllocation(
+ const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
+ const BufferAssignment& buffer_assignment) {
+ const BufferAllocation::Slice slice_a =
+ buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
+ const BufferAllocation::Slice slice_b =
+ buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
+ if (slice_a != slice_b) {
+ return InternalError(
+ "instruction %s %s does not share allocation with instruction %s %s",
+ a->ToString().c_str(), slice_a.ToString().c_str(),
+ b->ToString().c_str(), slice_b.ToString().c_str());
+ }
+ return Status::OK();
+}
+
// Checks that all buffers used during while loop iteration share the same
// buffer allocation. This includes buffers for while result, while init
// operand, condition parameter, body parameter and body result.
@@ -2111,37 +2130,65 @@ Status CheckWhileBuffersShareAllocation(
const BufferAssignment& buffer_assignment) {
return ShapeUtil::ForEachSubshapeWithStatus(
xla_while->shape(),
- [&buffer_assignment, &xla_while](const Shape& /*subshape*/,
- const ShapeIndex& index) -> Status {
- auto check = [&buffer_assignment](const HloInstruction* a,
- const HloInstruction* b,
- const ShapeIndex& index) -> Status {
- const BufferAllocation::Slice slice_a =
- buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
- const BufferAllocation::Slice slice_b =
- buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
- if (slice_a != slice_b) {
- return InternalError(
- "instruction %s %s does not share allocation with "
- "instruction %s %s",
- a->ToString().c_str(), slice_a.ToString().c_str(),
- b->ToString().c_str(), slice_b.ToString().c_str());
- }
- return Status::OK();
- };
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
const HloInstruction* condition_parameter =
xla_while->while_condition()->parameter_instruction(0);
const HloComputation* body = xla_while->while_body();
const HloInstruction* body_parameter = body->parameter_instruction(0);
const HloInstruction* body_result = body->root_instruction();
- TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
- TF_RETURN_IF_ERROR(check(xla_while, condition_parameter, index));
- TF_RETURN_IF_ERROR(check(xla_while, body_parameter, index));
- TF_RETURN_IF_ERROR(check(xla_while, body_result, index));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, xla_while->operand(0), index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, condition_parameter, index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, body_parameter, index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, body_result, index, buffer_assignment));
return Status::OK();
});
}
+// Checks that the buffers used in a conditional instruction are shared with the
+// operands and result as follows:
+// * The result buffer of the conditional should share the allocation with the
+// result buffers of the true and false computations.
+// * The buffer of operand 1 should share the allocation with the buffer of
+// the parameter 0 instruction of the true computation.
+// * The buffer of operand 2 should share the allocation with the buffer of
+// the parameter 0 instruction of the false computation.
+Status CheckConditionalBuffersShareAllocation(
+ const HloInstruction* conditional,
+ const BufferAssignment& buffer_assignment) {
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ conditional, conditional->true_computation()->root_instruction(),
+ index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ conditional, conditional->false_computation()->root_instruction(),
+ index, buffer_assignment));
+ return Status::OK();
+ }));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->operand(1)->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ return CheckHloBuffersShareAllocation(
+ conditional->operand(1),
+ conditional->true_computation()->parameter_instruction(0), index,
+ buffer_assignment);
+ }));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->operand(2)->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ return CheckHloBuffersShareAllocation(
+ conditional->operand(2),
+ conditional->false_computation()->parameter_instruction(0), index,
+ buffer_assignment);
+ }));
+ return Status::OK();
+}
+
} // namespace
std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
@@ -2184,6 +2231,31 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_body.ConsumeThunkSequence(), hlo);
}
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
+ const HloInstruction* hlo) {
+ // Check that the buffers used in conditional are shared with the operands and
+ // result appropriately.
+ TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
+ hlo, ir_emitter_context_->buffer_assignment()));
+
+ HloComputation* true_computation = hlo->true_computation();
+ IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation,
+ ir_emitter_context_);
+ TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true));
+
+ HloComputation* false_computation = hlo->false_computation();
+ IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation,
+ ir_emitter_context_);
+ TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false));
+
+ return MakeUnique<ConditionalThunk>(
+ GetAllocationSlice(*hlo->operand(0)),
+ GetAllocationSlice(*hlo->operand(1)),
+ GetAllocationSlice(*hlo->operand(2)),
+ std::move(*ir_emitter_true.ConsumeThunkSequence()),
+ std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo);
+}
+
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index c29fee0879..2923a79af0 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -28,7 +28,7 @@ namespace gpu {
namespace {
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
CHECK_EQ(HloOpcode::kConvolution, conv.opcode());
- return window_util::HasEvenPadding(conv.window()) &&
+ return window_util::HasSymmetricPadding(conv.window()) &&
!window_util::HasNegativePadding(conv.window()) &&
!window_util::HasDilation(conv.window());
}
@@ -43,7 +43,7 @@ HloInstruction* MaybePaddedAndSlicedInput(
const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums,
HloInstruction* input) {
HloComputation* computation = input->parent();
- if (!window_util::HasEvenPadding(conv_window) ||
+ if (!window_util::HasSymmetricPadding(conv_window) ||
window_util::HasBaseDilation(conv_window)) {
// If padding is uneven or has dilation, we insert a kPad instruction that
// applies positive padding and dilation.
@@ -190,7 +190,7 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
bool PadInsertion::CanonicalizeBackwardFilterConvolution(
HloInstruction* backward_conv) {
- if (window_util::HasEvenPadding(backward_conv->window())) {
+ if (window_util::HasSymmetricPadding(backward_conv->window())) {
return false;
}
@@ -285,7 +285,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* backward_conv) {
- if (window_util::HasEvenPadding(backward_conv->window())) {
+ if (window_util::HasSymmetricPadding(backward_conv->window())) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 625c3f8bea..2c3032d79b 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -41,6 +41,7 @@ class GpuExecutable;
class Thunk {
public:
enum class Kind {
+ kConditional,
kConvolution,
kCopy,
kCudnnBatchNormBackward,
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h
index a4f527fce0..fe3a954e18 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.h
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -40,4 +40,4 @@ StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor(
} // namespace gpu
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index 1773bb401d..c782d1b0ad 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -54,45 +54,96 @@ bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
return false;
}
+// Finds out the Tuple Shape of the new instruction after converting the element
+// type of the operands of the original instruction from `from_type` to
+// `to_type`.
+//
+// This routine assumes the resulting `shape` of the original instruction is a
+// non-nested tuple. This assumption is currently safe as only kTuple, kInfeed,
+// kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce
+// results with tuple shapes, and this routine is only called to convert the
+// result shapes of kBatchNorm* HLO instructions, which are non-nested tuples.
+Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type,
+ PrimitiveType to_type) {
+ std::vector<Shape> new_tuple_subshapes;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ Shape subshape = ShapeUtil::GetTupleElementShape(shape, i);
+ CHECK(!ShapeUtil::IsTuple(subshape));
+ if (subshape.element_type() == from_type) {
+ subshape = ShapeUtil::ChangeElementType(subshape, to_type);
+ }
+ new_tuple_subshapes.push_back(subshape);
+ }
+ return ShapeUtil::MakeTupleShape(new_tuple_subshapes);
+}
+
+// Converts the elements of the result of `hlo` to produce a new tuple with
+// shape `to_shape`.
+//
+// This routine assumes `hlo` is an instruction that produces a non-nested Tuple
+// as a result.
+HloInstruction* ConvertTupleElements(HloInstruction* hlo,
+ const Shape& to_shape) {
+ const Shape& shape = hlo->shape();
+ HloComputation* computation = hlo->parent();
+ std::vector<HloInstruction*> tuple_elements;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i);
+ HloInstruction* element = computation->AddInstruction(
+ HloInstruction::CreateGetTupleElement(ele_shape, hlo, i));
+ const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i);
+ CHECK(!ShapeUtil::IsTuple(ele_shape));
+ if (ele_shape.element_type() != to_ele_shape.element_type()) {
+ element = computation->AddInstruction(
+ HloInstruction::CreateConvert(to_ele_shape, element));
+ }
+ tuple_elements.push_back(element);
+ }
+ return computation->AddInstruction(
+ HloInstruction::CreateTuple(tuple_elements));
+}
+
} // namespace
HloElementTypeConverter::HloElementTypeConverter(
PrimitiveType eliminate_type, PrimitiveType replace_with_type)
: eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
+// This routine converts the arithmetic operations in the given module that use
+// eliminate_type_ to operations that use replace_with_type_.
StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
XLA_VLOG_LINES(
3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
+
+ if (eliminate_type_ == replace_with_type_) {
+ return false;
+ }
+
bool changed = false;
for (auto* computation : module->computations()) {
for (auto* hlo : computation->MakeInstructionPostOrder()) {
+ const auto opcode = hlo->opcode();
// These are ops where it does not make sense to convert them.
- if (hlo->opcode() == HloOpcode::kParameter ||
- hlo->opcode() == HloOpcode::kConstant ||
- hlo->opcode() == HloOpcode::kTuple ||
- hlo->opcode() == HloOpcode::kConvert ||
- hlo->opcode() == HloOpcode::kGetTupleElement ||
- hlo->opcode() == HloOpcode::kInfeed ||
- hlo->opcode() == HloOpcode::kOutfeed) {
+ if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
+ opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
+ opcode == HloOpcode::kGetTupleElement ||
+ opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
continue;
}
// We cannot change a CustomCall since we have no way of adjusting the
// called binary to expect the updated type.
- if (hlo->opcode() == HloOpcode::kCustomCall) {
+ if (opcode == HloOpcode::kCustomCall) {
continue;
}
// These are ops with embedded computations where it suffices to convert
// the embedded computations instead of converting the ops themselves.
- if (hlo->opcode() == HloOpcode::kWhile ||
- hlo->opcode() == HloOpcode::kCall ||
- hlo->opcode() == HloOpcode::kFusion ||
- hlo->opcode() == HloOpcode::kMap ||
- hlo->opcode() == HloOpcode::kReduce ||
- hlo->opcode() == HloOpcode::kReduceWindow ||
- hlo->opcode() == HloOpcode::kSelectAndScatter ||
- hlo->opcode() == HloOpcode::kConditional) {
+ if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
+ opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
+ opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
+ opcode == HloOpcode::kSelectAndScatter ||
+ opcode == HloOpcode::kConditional) {
continue;
}
TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
@@ -106,6 +157,11 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
continue;
}
+ // Handle instructions that perform arithmetic operations and contain
+ // operands with eliminate_type_.
+ //
+ // First, convert the operands with eliminate_type_ to operands with
+ // replace_with_type_.
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : hlo->operands()) {
if (operand->shape().element_type() == eliminate_type_) {
@@ -114,6 +170,10 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
new_operands.push_back(operand);
}
+ // Then find out the result type of the new instruction with the same
+ // opcode but using the converted operands, create the new instruction,
+ // and convert the result of the new instruction back to match the result
+ // type of the original instruction.
HloInstruction* new_hlo;
if (hlo->shape().element_type() == eliminate_type_) {
Shape shape =
@@ -121,10 +181,20 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
new_hlo = computation->AddInstruction(
hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule()));
new_hlo = ToElementType(new_hlo, eliminate_type_);
+ } else if (ShapeUtil::IsTuple(hlo->shape())) {
+ Shape old_shape = hlo->shape();
+ Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
+ replace_with_type_);
+ new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
+ new_shape, new_operands, hlo->GetModule()));
+ // Convert the elements of the result of `new_hlo` to produce a new
+ // tuple with shape `old_shape`.
+ new_hlo = ConvertTupleElements(new_hlo, old_shape);
} else {
new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
hlo->shape(), new_operands, hlo->GetModule()));
}
+
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo));
changed = true;
}
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
new file mode 100644
index 0000000000..cb94d9f19b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
@@ -0,0 +1,121 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+
+class HloElementTypeConverterTest : public HloTestBase {
+ public:
+ std::unique_ptr<HloModule> CreateModuleFromHloString(
+ const string& hlo_string) {
+ return HloRunner::CreateModuleFromString(hlo_string,
+ GetDebugOptionsForTest())
+ .ValueOrDie();
+ }
+};
+
+TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) {
+ const string& hlo_string = R"(
+ HloModule custom_call
+ ENTRY CustomCall {
+ constant = bf16[1]{0} constant({12345})
+ ROOT custom-call = bf16[1,2,3]{0,2,1} custom-call(constant),
+ custom_call_target="foo"
+ }
+ )";
+ auto module = CreateModuleFromHloString(hlo_string);
+ HloElementTypeConverter type_converter(BF16, F32);
+ TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
+ EXPECT_FALSE(converted);
+}
+
+TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) {
+ const string& hlo_string = R"(
+ HloModule InfeedOutfeed
+ ENTRY RoundTrip16MiBR1.v2 {
+ ROOT infeed = bf16[4]{0} infeed()
+ outfeed = () outfeed(infeed)
+ }
+ )";
+ auto module = CreateModuleFromHloString(hlo_string);
+ HloElementTypeConverter type_converter(BF16, F32);
+ TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
+ EXPECT_FALSE(converted);
+}
+
+TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) {
+ const string& hlo_string = R"(
+ HloModule NestedTuples
+ ENTRY NestedTuples.v5 {
+ constant.4 = bf16[] constant(42)
+ constant.2 = f32[2]{0} constant({1, 2})
+ constant.3 = bf16[] constant(42)
+ add = bf16[] add(constant.2, constant.3)
+ tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add)
+ constant.5 = bf16[2]{0} constant({22, 44})
+ ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5)
+ }
+ )";
+
+ auto module = CreateModuleFromHloString(hlo_string);
+ HloElementTypeConverter type_converter(BF16, F32);
+ TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
+ EXPECT_TRUE(converted);
+ const HloInstruction* bf16_op =
+ module->entry_computation()->root_instruction()->operand(0)->operand(1);
+ EXPECT_THAT(bf16_op, op::Convert(op::Add(op::Constant(), op::Convert())));
+}
+
+TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) {
+ const string& hlo_string = R"(
+ HloModule BatchNormGrad
+ ENTRY BatchNormGrad.v6 {
+ constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/
+ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0},
+ {0} }, { /*i1=1*/ {0}, {0} } } })
+ constant.5 = bf16[2]{0} constant({1, 1})
+ constant.6 = bf16[2]{0} constant({0, 0})
+ constant.7 = bf16[2]{0} constant({1, 1})
+ constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/
+ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/
+ {5}, {6} }, { /*i1=1*/ {7}, {8} } } })
+ ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0})
+ batch-norm-grad(constant.4, constant.5, constant.6, constant.7,
+ constant.8), epsilon=0, feature_index=2
+ }
+ )";
+
+ auto module = CreateModuleFromHloString(hlo_string);
+ HloElementTypeConverter type_converter(BF16, F32);
+ TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
+ EXPECT_TRUE(converted);
+ const HloInstruction* tuple_instr =
+ module->entry_computation()->root_instruction();
+ ::testing::Matcher<const ::xla::HloInstruction*> batch_norm =
+ op::BatchNormGrad();
+ EXPECT_THAT(tuple_instr,
+ op::Tuple(op::Convert(op::GetTupleElement(batch_norm, 0)),
+ op::Convert(op::GetTupleElement(batch_norm, 1)),
+ op::Convert(op::GetTupleElement(batch_norm, 2))));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 3a846a7529..e3f5c17e35 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -166,6 +167,34 @@ StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
return std::move(result);
}
+// For one particular placement of a window in a base shape (the placement is
+// represented as `window_count_index`), iterates inside the window. Translates
+// the window index into base index. If the base index is within bound, call `f`
+// with the base index.
+void IterateThroughWindow(
+ const Shape& window_shape, const Window& window, const Shape& base_shape,
+ const tensorflow::gtl::ArraySlice<int64>& window_count_index,
+ const std::function<void(const std::vector<int64>&)>& f) {
+ const int64 rank = ShapeUtil::Rank(base_shape);
+ DimensionVector window_index(rank);
+ std::fill(window_index.begin(), window_index.end(), 0);
+ do {
+ std::vector<int64> base_index(rank);
+ bool out_of_bound = false;
+ for (int64 i = 0; i < rank; ++i) {
+ base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
+ window_index[i] - window.dimensions(i).padding_low();
+ if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
+ out_of_bound = true;
+ break;
+ }
+ }
+ if (!out_of_bound) {
+ f(base_index);
+ }
+ } while (IndexUtil::BumpIndices(window_shape, &window_index));
+}
+
} // namespace
template <typename ReturnT, typename ElementwiseT>
@@ -945,14 +974,21 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
out_index[output_spatial_dim] * window_dim.stride() -
window_dim.padding_low() +
rhs_spatial_index[ki] * window_dim.window_dilation();
- // Skip if the lhs (input) index is to be dilated.
- if (undilated_index % window_dim.base_dilation() != 0) {
+ // Skip if the lhs (input) index is to be dilated. As an
+ // optimization, skip this mod if there's no dilation.
+ if (window_dim.base_dilation() > 1 &&
+ undilated_index % window_dim.base_dilation() != 0) {
goto cnt;
}
- // Calculate the actual lhs (input) index after dilation.
- lhs_index[input_spatial_dim] =
- undilated_index / window_dim.base_dilation();
+ // Calculate the actual lhs (input) index after dilation. As an
+ // optimization, skip this integer divide if there's no dilation.
+ if (window_dim.base_dilation() > 1) {
+ lhs_index[input_spatial_dim] =
+ undilated_index / window_dim.base_dilation();
+ } else {
+ lhs_index[input_spatial_dim] = undilated_index;
+ }
// Skip if input index is not in bound.
if (!(lhs_index[input_spatial_dim] >= 0 &&
@@ -1413,6 +1449,111 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
+ auto operand = select_and_scatter->operand(0);
+ auto source = select_and_scatter->operand(1);
+ const Window& window = select_and_scatter->window();
+
+ const Literal& init_literal =
+ parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2));
+ TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+ auto init_scalar = init_literal.Get<ReturnT>({});
+
+ auto result = Literal::CreateFromShape(select_and_scatter->shape());
+
+ // Initialize result array with the init value.
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ return init_scalar;
+ }));
+
+ std::vector<int64> window_dimension_sizes;
+ for (const auto& window_dimension : window.dimensions()) {
+ window_dimension_sizes.push_back(window_dimension.size());
+ }
+ const Shape window_shape = ShapeUtil::MakeShape(
+ operand->shape().element_type(), window_dimension_sizes);
+
+ HloComputation* select = select_and_scatter->select();
+ HloComputation* scatter = select_and_scatter->scatter();
+
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
+
+ int64 rank = ShapeUtil::Rank(operand_literal.shape());
+
+ HloEvaluator embedded_evaluator;
+ DimensionVector source_index(rank);
+
+ std::fill(source_index.begin(), source_index.end(), 0);
+ do {
+ // For each element in `source`, we place a window in `operand`. For each
+ // window placement, we iterate inside the window twice:
+ //
+ // 1. Find the selected index by applying `select` function to all
+ // elements. E.g., If the `select` function is GreaterEqual, the first
+ // iteration through the window finds the biggest value and returns its
+ // index.
+ //
+ // 2. Using the selected index, scatter value from `source` to result. We
+ // do this by iterating through the window, and compare each index with
+ // the selected index.
+ tensorflow::gtl::optional<ReturnT> selected_val;
+ tensorflow::gtl::optional<std::vector<int64>> selected_index;
+
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), source_index,
+ [&](const std::vector<int64>& operand_index) {
+ auto curr_val = operand_literal.Get<ReturnT>(operand_index);
+ if (!selected_val) {
+ selected_val = curr_val;
+ selected_index = operand_index;
+ }
+ const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
+ const auto selected_val_literal =
+ Literal::CreateR0<ReturnT>(*selected_val);
+
+ const std::vector<const Literal*> args = {
+ curr_val_literal.get(), selected_val_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*select, args)
+ .ConsumeValueOrDie();
+ bool selected = computed_result->Get<bool>({});
+ if (selected) {
+ selected_val = curr_val;
+ selected_index = operand_index;
+ }
+ embedded_evaluator.ResetVisitStates();
+ });
+
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), source_index,
+ [&](const std::vector<int64>& operand_index) {
+ if (std::equal(operand_index.begin(), operand_index.end(),
+ selected_index->begin())) {
+ auto source = source_literal.Get<ReturnT>(source_index);
+ auto scattered = result->Get<ReturnT>(operand_index);
+ const auto source_literal = Literal::CreateR0<ReturnT>(source);
+ const auto scattered_literal =
+ Literal::CreateR0<ReturnT>(scattered);
+
+ const std::vector<const Literal*> args = {
+ source_literal.get(), scattered_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*scatter, args)
+ .ConsumeValueOrDie();
+ result->Set(operand_index, computed_result->Get<ReturnT>({}));
+ // Clear visit states so that the we can use the evaluator again
+ // on the same computation.
+ embedded_evaluator.ResetVisitStates();
+ }
+ });
+ } while (IndexUtil::BumpIndices(source->shape(), &source_index));
+
+ parent_->evaluated_[select_and_scatter] = std::move(result);
+ return Status::OK();
+ }
+
Status HandleReduceWindow(HloInstruction* reduce_window) override {
auto operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
@@ -1461,39 +1602,28 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
std::fill(window_index.begin(), window_index.end(), 0);
std::fill(operand_index.begin(), operand_index.end(), 0);
- do {
- bool out_of_bound = false;
- for (int i = 0; i < operand_index.size(); ++i) {
- operand_index[i] =
- output_index[i] * window.dimensions(i).stride() +
- window_index[i] - window.dimensions(i).padding_low();
- if (operand_index[i] < 0 ||
- operand_index[i] >= operand_literal.shape().dimensions(i)) {
- out_of_bound = true;
- break;
- }
- }
- if (!out_of_bound) {
- auto curr_val = operand_literal.Get<ReturnT>(operand_index);
-
- // Evaluate computation with specified literal operands.
- const auto curr_val_literal =
- Literal::CreateR0<ReturnT>(curr_val);
- const auto result_val_literal =
- Literal::CreateR0<ReturnT>(result_val);
- const std::vector<const Literal*> args = {
- curr_val_literal.get(), result_val_literal.get()};
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator.Evaluate<const Literal*>(*function, args)
- .ConsumeValueOrDie();
-
- // Clear visit states so that the we can use the evaluate again on
- // the same computation.
- embedded_evaluator.ResetVisitStates();
-
- result_val = computed_result->Get<ReturnT>({});
- }
- } while (IndexUtil::BumpIndices(window_shape, &window_index));
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), output_index,
+ [&](const std::vector<int64>& operand_index) {
+ auto curr_val = operand_literal.Get<ReturnT>(operand_index);
+
+ // Evaluate computation with specified literal operands.
+ const auto curr_val_literal =
+ Literal::CreateR0<ReturnT>(curr_val);
+ const auto result_val_literal =
+ Literal::CreateR0<ReturnT>(result_val);
+ const std::vector<const Literal*> args = {
+ curr_val_literal.get(), result_val_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*function, args)
+ .ConsumeValueOrDie();
+
+ // Clear visit states so that the we can use the evaluate again
+ // on the same computation.
+ embedded_evaluator.ResetVisitStates();
+
+ result_val = computed_result->Get<ReturnT>({});
+ });
return result_val;
}));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 02bb8b0a47..3b2b697e49 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
#include <memory>
@@ -195,4 +195,4 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index 849aac0b12..f0df93b61d 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -40,83 +40,75 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) {
}
}
-std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
+std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis) {
- using HloComputationInfo = HloProfilePrinter::HloComputationInfo;
- using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo;
-
- HloComputationInfo* computation_infos =
- new HloComputationInfo[hlo_profile_index_map.computation_count()];
-
- // There are two "indices" in play here. The first one is the index of the
- // HloComputationInfo or HloInstructionInfo in the array that contains said
- // HloComputationInfo or HloInstructionInfo. The second index is the index of
- // the HloComputationInfo or HloInstructionInfo in the profile counters array,
- // as decided by hlo_profile_index_map. The latter index is always referred
- // to as "profile_index".
-
- size_t computation_index_in_static_data = 0;
- size_t max_profile_index = hlo_profile_index_map.total_count();
- for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) {
- CHECK_LT(pair.second, max_profile_index);
+ using HloComputationInfo = HloProfilePrinterData::HloComputationInfo;
+ using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo;
+
+ size_t profile_counters_size = hlo_profile_index_map.total_count();
+
+ std::unique_ptr<HloProfilePrinterData> profile_printer_data =
+ MakeUnique<HloProfilePrinterData>();
+ profile_printer_data->set_profile_counters_size(profile_counters_size);
+ profile_printer_data->mutable_computation_infos()->Reserve(
+ hlo_profile_index_map.computation_count());
+
+ const auto& computation_to_profile_idx_map =
+ hlo_profile_index_map.computation_to_profile_idx();
+
+ // computation_to_profile_idx_map's order is not deterministic so create a
+ // deterministic computation_and_profile_idx_list so that we end up with a
+ // deterministic HloProfilePrinterData protobuf.
+
+ std::vector<std::pair<const HloComputation*, int64>>
+ computation_and_profile_idx_list(computation_to_profile_idx_map.begin(),
+ computation_to_profile_idx_map.end());
+
+ // The profile indices were computed deterministically in
+ // HloProfileIndexMap::HloProfileIndexMap.
+ c_sort(computation_and_profile_idx_list,
+ [](const std::pair<const HloComputation*, int64>& left,
+ const std::pair<const HloComputation*, int64>& right) {
+ return left.second < right.second;
+ });
+
+ for (const auto& pair : computation_and_profile_idx_list) {
+ CHECK_LT(pair.second, profile_counters_size);
const HloComputation* computation = pair.first;
- size_t current_computation_index = computation_index_in_static_data++;
HloComputationInfo* computation_info =
- &computation_infos[current_computation_index];
+ profile_printer_data->add_computation_infos();
- computation_info->name = strdup(computation->name().c_str());
- computation_info->profile_index = pair.second;
- computation_info->instructions =
- new HloInstructionInfo[computation->instruction_count()];
- computation_info->instructions_size = computation->instruction_count();
+ computation_info->set_name(computation->name());
+ computation_info->set_profile_index(pair.second);
+ computation_info->mutable_instruction_infos()->Reserve(
+ computation->instruction_count());
- size_t instruction_index_in_static_data = 0;
for (const HloInstruction* hlo : computation->instructions()) {
- HloProfilePrinter::HloInstructionInfo* instruction_info =
- &computation_info->instructions[instruction_index_in_static_data++];
- instruction_info->long_name = strdup(hlo->ToString().c_str());
- instruction_info->short_name = strdup(
- hlo->ToString(HloPrintOptions().set_compact_operands(true)).c_str());
- instruction_info->category = strdup(hlo->ToCategory().c_str());
- instruction_info->flop_count = cost_analysis.flop_count(*hlo);
- instruction_info->transcendental_count =
- cost_analysis.transcendental_count(*hlo);
- instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo);
- instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo);
- instruction_info->profile_index =
- hlo_profile_index_map.GetProfileIndexFor(*hlo);
- CHECK_LT(instruction_info->profile_index, max_profile_index);
+ HloInstructionInfo* instruction_info =
+ computation_info->add_instruction_infos();
+ instruction_info->set_long_name(hlo->ToString());
+ instruction_info->set_short_name(
+ hlo->ToString(HloPrintOptions().set_compact_operands(true)));
+ instruction_info->set_category(hlo->ToCategory());
+ instruction_info->set_flop_count(cost_analysis.flop_count(*hlo));
+ instruction_info->set_transcendental_count(
+ cost_analysis.transcendental_count(*hlo));
+ instruction_info->set_bytes_accessed(cost_analysis.bytes_accessed(*hlo));
+ instruction_info->set_optimal_seconds(
+ cost_analysis.optimal_seconds(*hlo));
+ instruction_info->set_profile_index(
+ hlo_profile_index_map.GetProfileIndexFor(*hlo));
}
}
- auto deleter = [](HloProfilePrinter::HloComputationInfo* computation_infos,
- int64 computation_infos_size) {
- for (int64 i = 0; i < computation_infos_size; i++) {
- HloInstructionInfo* instruction_infos = computation_infos[i].instructions;
- for (int64 j = 0; j < computation_infos[i].instructions_size; j++) {
- // We can't make instruction_infos[j].long_name etc. non-const pointers
- // since they may point into static storage, so we have a const_cast
- // here.
- free(const_cast<char*>(instruction_infos[j].long_name));
- free(const_cast<char*>(instruction_infos[j].short_name));
- free(const_cast<char*>(instruction_infos[j].category));
- }
- delete[] instruction_infos;
- free(const_cast<char*>(computation_infos[i].name));
- }
- delete[] computation_infos;
- };
-
- return MakeUnique<HloProfilePrinter>(
- computation_infos, hlo_profile_index_map.computation_count(),
- /*profile_counters_size=*/max_profile_index, deleter);
+ return profile_printer_data;
}
HloExecutionProfile::HloExecutionProfile(
- const HloProfilePrinter* hlo_profile_printer,
+ const HloProfilePrinterData* hlo_profile_printer_data,
const HloProfileIndexMap* hlo_profile_index_map)
- : hlo_profile_printer_(*hlo_profile_printer),
+ : hlo_profile_printer_data_(*hlo_profile_printer_data),
hlo_profile_index_map_(*hlo_profile_index_map),
profile_counters_(
/*count*/ hlo_profile_index_map_.total_count(),
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h
index 1a6b069609..6fb91b9bef 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.h
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h
@@ -77,8 +77,8 @@ class HloProfileIndexMap {
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx_;
};
-// Create an instance of `HloProfilePrinter` that owns its memory.
-std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
+// Create an instance of `HloProfilePrinterData`.
+std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
const HloProfileIndexMap& hlo_profile_index_map,
const HloCostAnalysis& cost_analysis);
@@ -90,7 +90,7 @@ class HloExecutionProfile {
public:
using DeviceDescription = perftools::gputools::DeviceDescription;
- HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer,
+ HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data,
const HloProfileIndexMap* hlo_profile_index_map);
// Record how many cycles this HLO took to execute.
@@ -117,11 +117,10 @@ class HloExecutionProfile {
// debugging; e.g. emits cycle counts, execution time at the nominal device
// frequency, and the effective throughput given the provided cost_analysis
// for the operations in a given computation. Returns an empty string if it
- // wasn't possible to generate a printable version. cost_analysis should be a
- // clean analysis that can be used to visit the computation.
+ // wasn't possible to generate a printable version.
string ToString(const DeviceDescription& device_description) const {
- return hlo_profile_printer_.ToString(profile_counters_.data(),
- device_description.clock_rate_ghz());
+ return PrintHloProfile(hlo_profile_printer_data_, profile_counters_.data(),
+ device_description.clock_rate_ghz());
}
std::vector<int64>* mutable_profile_counters() { return &profile_counters_; }
@@ -130,7 +129,7 @@ class HloExecutionProfile {
}
private:
- const HloProfilePrinter& hlo_profile_printer_;
+ const HloProfilePrinterData& hlo_profile_printer_data_;
const HloProfileIndexMap& hlo_profile_index_map_;
// Stores per-Hlo profile counters. This is the only thing that changes when
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
index b1e6729e2b..a0cb28246d 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
@@ -73,8 +73,8 @@ TEST_F(HloExecutionProfileTest, Basic) {
HloCostAnalysis cost_analysis(shape_size_function);
HloProfileIndexMap profile_index_map(*hlo_module);
- std::unique_ptr<HloProfilePrinter> profile_printer =
- CreateHloProfilePrinter(profile_index_map, cost_analysis);
+ std::unique_ptr<HloProfilePrinterData> profile_printer =
+ CreateHloProfilePrinterData(profile_index_map, cost_analysis);
HloExecutionProfile execution_profile(profile_printer.get(),
&profile_index_map);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 90121f7ffe..a889c35aeb 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -404,6 +404,9 @@ HloInstruction::CreateCrossReplicaSum(
tensorflow::StringPiece outfeed_config) {
std::unique_ptr<HloInstruction> instruction =
WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()));
+ CHECK(ShapeUtil::Compatible(operand->shape(), shape))
+ << "Outfeed shape " << shape << " must be compatible with operand shape "
+ << operand->shape();
instruction->AppendOperand(operand);
instruction->outfeed_config_ = outfeed_config.ToString();
instruction->outfeed_shape_ = shape;
@@ -669,6 +672,58 @@ HloInstruction::CreateSelectAndScatter(
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateBroadcastSequence(
+ const Shape& output_shape, HloInstruction* operand,
+ const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
+ adder) {
+ CHECK(ShapeUtil::IsScalar(operand->shape()) ||
+ ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
+ Shape broadcast_shape = ShapeUtil::ChangeElementType(
+ output_shape, operand->shape().element_type());
+ // Do explicit broadcast for scalar.
+ if (ShapeUtil::IsScalar(operand->shape())) {
+ auto broadcast =
+ HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
+ broadcast->set_metadata(operand->metadata());
+ if (operand->has_sharding()) {
+ broadcast->set_sharding(operand->sharding());
+ }
+ return broadcast;
+ }
+ // Do explicit broadcast for degenerate broadcast.
+ std::vector<int64> broadcast_dimensions;
+ std::vector<int64> reshaped_dimensions;
+ for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
+ if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
+ broadcast_dimensions.push_back(i);
+ reshaped_dimensions.push_back(operand->shape().dimensions(i));
+ } else {
+ CHECK_EQ(operand->shape().dimensions(i), 1)
+ << "An explicit broadcast sequence requires the broadcasted "
+ "dimensions to be trivial; operand: "
+ << operand->ToString() << "; output_shape: " << output_shape;
+ }
+ }
+ // Eliminate the size one dimensions.
+ HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(operand->shape().element_type(),
+ reshaped_dimensions),
+ operand));
+ reshaped_operand->set_metadata(operand->metadata());
+ if (operand->has_sharding()) {
+ reshaped_operand->set_sharding(operand->sharding());
+ }
+ // Broadcast 'reshape' up to the larger size.
+ auto broadcast = HloInstruction::CreateBroadcast(
+ broadcast_shape, reshaped_operand, broadcast_dimensions);
+ broadcast->set_metadata(operand->metadata());
+ if (operand->has_sharding()) {
+ broadcast->set_sharding(operand->sharding());
+ }
+ return broadcast;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
const PaddingConfig& padding_config) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e700ec1d29..5e89dc79be 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -409,6 +409,20 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ // Creates a sequence of instructions that performs an explicit broadcast of
+ // the operand to the target shape.
+ //
+ // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
+ // returned as a unique_ptr for API consistency with other factory methods in
+ // this interface.
+ //
+ // TODO(b/72173833) Ideally HloComputations would always be present, and so
+ // the adder being passed by the caller would not be necessary.
+ static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
+ const Shape& output_shape, HloInstruction* operand,
+ const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
+ adder);
+
// Creates a pad instruction, where the operand is padded on the edges and
// between the elements with the given padding value.
static std::unique_ptr<HloInstruction> CreatePad(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 3af3b29ced..1038ab5555 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -712,8 +712,8 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
{1, 2},
{3, 4},
})));
- auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
- auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
+ auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
+ auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
auto outfeed10 = builder.AddInstruction(
HloInstruction::CreateOutfeed(shape10, constant, ""));
auto outfeed01 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 992f55788b..9206cdac05 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -83,6 +83,7 @@ HLO_MATCHER(Abs);
HLO_MATCHER(Add);
HLO_MATCHER(Bitcast);
HLO_MATCHER(Broadcast);
+HLO_MATCHER(BatchNormGrad);
HLO_MATCHER(Call);
HLO_MATCHER(Ceil);
HLO_MATCHER(Clamp);
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 58bb942211..99d8dd04e5 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -523,7 +523,15 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
std::unordered_map<HloComputation*, HloComputation*> clone_map;
for (auto& computation : computations_) {
- auto cloned_computation = computation->Clone(suffix);
+ if (computation->IsFusionComputation()) {
+ // Cloning of a fused computation is handled by its fusion instruction.
+ continue;
+ }
+
+ // When cloning a computation, pass in the new module, so that for any
+ // fusion instruction in this computation, the fused computation will be
+ // deep cloned to the new module.
+ auto cloned_computation = computation->Clone(suffix, module.get());
InsertOrDie(&clone_map, computation.get(), cloned_computation.get());
if (entry_computation_ == computation.get()) {
@@ -537,8 +545,15 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
for (auto* instruction : cloned_computation->instructions()) {
// Rewrite instruction's called_computation to point to the cloned
// computations.
- instruction->ReplaceCalledComputations(
- [&](HloComputation* hlo) { return FindOrDie(clone_map, hlo); });
+ instruction->ReplaceCalledComputations([&](HloComputation* hlo) {
+ if (hlo->IsFusionComputation()) {
+ // Cloning of a fused computation has already been handled when its
+ // fusion instruction is cloned. So this hlo computation is already
+ // the cloned one.
+ return hlo;
+ }
+ return FindOrDie(clone_map, hlo);
+ });
}
}
return module;
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 0f5d3dccb7..cd51fa4e85 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -105,6 +105,48 @@ TEST_F(HloModuleTest, CloneTest) {
}
}
+TEST_F(HloModuleTest, CloneHasFusion) {
+ auto module = CreateNewModule();
+
+ // Create the fused computation.
+ HloComputation* fused_computation;
+ {
+ auto b = HloComputation::Builder("Fused");
+ auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
+ b.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x));
+ fused_computation = module->AddEmbeddedComputation(b.Build());
+ }
+
+ // Create the entry computation.
+ {
+ auto b = HloComputation::Builder("Entry");
+ auto input = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ b.AddInstruction(
+ HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
+ /*operands=*/{input}, fused_computation));
+ module->AddEntryComputation(b.Build());
+ }
+
+ auto post_order = module->MakeComputationPostOrder();
+ auto cloned_module = module->Clone("copy");
+ auto post_order_copied = cloned_module->MakeComputationPostOrder();
+
+ EXPECT_EQ(post_order.size(), post_order_copied.size());
+ for (auto origin = post_order.begin(), copied = post_order_copied.begin();
+ origin != post_order.end() && copied != post_order_copied.end();
+ ++origin, ++copied) {
+ if ((*origin)->name() == "Fused") {
+ // Clone of the fused computation is handled when its fusion instruction
+ // is cloned, which always use suffix ".clone".
+ EXPECT_EQ((*origin)->name() + ".clone", (*copied)->name());
+ } else {
+ EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name());
+ }
+ }
+}
+
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
// Create a module with a diamond call graph of computations.
auto module = CreateNewModule();
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 6f6e679a21..68e3c9618c 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -249,7 +249,7 @@ bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
std::vector<string> pieces;
pieces.push_back(name);
- for (auto* computation : module_->computations()) {
+ for (auto* computation : module_->MakeNonfusionComputations()) {
pieces.push_back(tensorflow::strings::Printf("computation %s:",
computation->name().c_str()));
const auto all = computation->MakeInstructionPostOrder();
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 33bafd05c1..aba66114de 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -310,5 +311,56 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
*dataflow));
}
+// Regression test for HloOrdering::ToString() crashing when fed a computation
+// containing a fusion node.
+TEST_F(HloOrderingTest, ToStringDoesNotCrash) {
+ const char* module_str = R"(
+HloModule test_module
+
+body.v8 {
+ prev.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
+ get-tuple-element.4 = s32[] get-tuple-element(prev.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.4, constant.1)
+ get-tuple-element.5 = f32[3]{0} get-tuple-element(prev.1), index=3
+ get-tuple-element.6 = f32[3]{0} get-tuple-element(prev.1), index=1
+ get-tuple-element.7 = f32[3]{0} get-tuple-element(prev.1), index=2
+ ROOT tuple = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(add, get-tuple-element.5, get-tuple-element.6, get-tuple-element.7)
+}
+
+condition.v4 {
+ constant.2 = s32[] constant(2)
+ prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
+ get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
+ ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8)
+}
+
+fused_computation {
+ get-tuple-element.5.param_1 = f32[3]{0} parameter(1)
+ get-tuple-element.6.param_2 = f32[3]{0} parameter(2)
+ add.4 = f32[3]{0} add(get-tuple-element.5.param_1, get-tuple-element.6.param_2)
+ get-tuple-element.7.param_1.1 = f32[3]{0} parameter(0)
+ ROOT add.5 = f32[3]{0} add(add.4, get-tuple-element.7.param_1.1)
+}
+
+ENTRY while.v11 {
+ constant.5 = s32[] constant(0)
+ constant.6 = f32[3]{0} constant({1, 1, 1})
+ constant.7 = f32[3]{0} constant({2, 2, 2})
+ constant.8 = f32[3]{0} constant({3, 3, 3})
+ tuple.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(constant.5, constant.6, constant.7, constant.8)
+ while = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) while(tuple.1), condition=condition.v4, body=body.v8
+ get-tuple-element.9 = f32[3]{0} get-tuple-element(while), index=3
+ get-tuple-element.10 = f32[3]{0} get-tuple-element(while), index=1
+ get-tuple-element.11 = f32[3]{0} get-tuple-element(while), index=2
+ ROOT fusion = f32[3]{0} fusion(get-tuple-element.9, get-tuple-element.10, get-tuple-element.11), kind=kLoop, calls=fused_computation
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(module_str));
+ DependencyHloOrdering ordering(module.get());
+ ordering.ToString(); // Shouldn't crash.
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc
index e944ad1513..dcc2279301 100644
--- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc
+++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc
@@ -18,20 +18,20 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
namespace xla {
-string HloProfilePrinter::ToString(const int64* counters,
- double clock_rate_ghz) const {
+string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
+ const int64* counters, double clock_rate_ghz) {
+ using HloComputationInfo = HloProfilePrinterData::HloComputationInfo;
+ using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo;
+
string result;
- for (int computation_idx = 0; computation_idx < computation_infos_size_;
- computation_idx++) {
- const HloComputationInfo& computation = computation_infos_[computation_idx];
- const HloInstructionInfo* instructions_begin = computation.instructions;
- const HloInstructionInfo* instructions_end =
- computation.instructions + computation.instructions_size;
+ for (const HloComputationInfo& computation_info :
+ hlo_profile_printer_data.computation_infos()) {
+ const auto& instruction_infos = computation_info.instruction_infos();
bool any_instruction_profiled =
- std::any_of(instructions_begin, instructions_end,
+ std::any_of(instruction_infos.begin(), instruction_infos.end(),
[&](const HloInstructionInfo& instruction_info) {
- return counters[instruction_info.profile_index] != 0;
+ return counters[instruction_info.profile_index()] != 0;
});
if (!any_instruction_profiled) {
@@ -41,16 +41,19 @@ string HloProfilePrinter::ToString(const int64* counters,
// Once we start using this in AOT for real, we will probably need a more
// minimal version of HumanReadableProfileBuilder.
HumanReadableProfileBuilder builder(
- computation.name, counters[computation.profile_index], clock_rate_ghz);
+ computation_info.name(), counters[computation_info.profile_index()],
+ clock_rate_ghz);
- for (const auto* instruction = instructions_begin;
- instruction != instructions_end; instruction++) {
+ for (const auto& instruction_info : instruction_infos) {
builder.AddOp(
- /*op_name=*/instruction->long_name,
- /*short_name=*/instruction->short_name, instruction->category,
- counters[instruction->profile_index], instruction->flop_count,
- instruction->transcendental_count, instruction->bytes_accessed,
- instruction->optimal_seconds);
+ /*op_name=*/instruction_info.long_name(),
+ /*short_name=*/instruction_info.short_name(),
+ instruction_info.category(),
+ counters[instruction_info.profile_index()],
+ instruction_info.flop_count(),
+ instruction_info.transcendental_count(),
+ instruction_info.bytes_accessed(),
+ instruction_info.optimal_seconds());
}
result += builder.ToString();
@@ -58,10 +61,4 @@ string HloProfilePrinter::ToString(const int64* counters,
return result;
}
-
-HloProfilePrinter::~HloProfilePrinter() {
- if (deleter_) {
- deleter_(computation_infos_, computation_infos_size_);
- }
-}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h
index 2f056490ae..b72325c755 100644
--- a/tensorflow/compiler/xla/service/hlo_profile_printer.h
+++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h
@@ -13,91 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_
#include <cstdint>
#include <string>
#include <vector>
+#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
-// Instances of this class can pretty-print profile counters gathered from
-// running an XLA computation without having access to the backing module.
-class HloProfilePrinter {
- public:
- // Holds meta information about an HloInstruction.
- //
- // The pointer-typed fields can be owning or non-owning -- this decision is
- // manifested as the deleter_ function in the containing HloProfilePrinter.
- struct HloInstructionInfo {
- // Textual information for pretty printing.
- const char* long_name;
- const char* short_name;
- const char* category;
-
- // Metrics computed by HloCostAnalysis.
- float flop_count;
- float transcendental_count;
- float bytes_accessed;
- float optimal_seconds;
-
- // The index into the profile counters array for the HloInstruction
- // corresponding to this HloInstructionInfo.
- int64 profile_index;
- };
-
- // Holds meta information about an HloComputation.
- //
- // The pointer-typed fields can be owning or non-owning -- this decision is
- // manifested as the deleter_ function in the containing HloProfilePrinter.
- struct HloComputationInfo {
- const char* name;
-
- // The index into the profile counters array for the HloInstruction
- // corresponding to this HloComputationInfo.
- int64 profile_index;
-
- HloInstructionInfo* instructions;
- int64 instructions_size;
- };
-
- HloProfilePrinter(
- HloComputationInfo* computation_infos, int64 computation_infos_size,
- int64 profile_counters_size,
- std::function<void(HloComputationInfo*, int64)> deleter = nullptr)
- : computation_infos_(computation_infos),
- computation_infos_size_(computation_infos_size),
- profile_counters_size_(profile_counters_size),
- deleter_(std::move(deleter)) {}
-
- HloProfilePrinter(HloProfilePrinter&& other) {
- std::swap(other.computation_infos_, computation_infos_);
- std::swap(other.computation_infos_size_, computation_infos_size_);
- std::swap(other.deleter_, deleter_);
- }
-
- HloProfilePrinter(const HloProfilePrinter&) = delete;
- HloProfilePrinter& operator=(const HloProfilePrinter&) = delete;
-
- // Converts the profile counter sequence `counters` to a human readable string
- // representation.
- string ToString(const int64* counters, double clock_rate_ghz) const;
-
- // Returns the size of the profile buffer expected by this printer.
- int64 profile_counters_size() const { return profile_counters_size_; }
-
- ~HloProfilePrinter();
-
- private:
- // The `computation_infos_` field can be owning or non-owning -- this decision
- // is manifested as the deleter_ function.
- HloComputationInfo* computation_infos_ = nullptr;
- int64 computation_infos_size_ = 0;
- int64 profile_counters_size_ = 0;
- std::function<void(HloComputationInfo*, int64)> deleter_;
-};
+// Pretty-print an array of profile counters using hlo_profile_printer_data.
+string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
+ const int64* counters, double clock_rate_ghz);
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto
new file mode 100644
index 0000000000..9f22b733fe
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package xla;
+
+option cc_enable_arenas = true;
+
+// Describes how to pretty-print a profile counter array gathered for a specific
+// HloModule.
+message HloProfilePrinterData {
+ // Pretty-printer information about an HloInstruction.
+ message HloInstructionInfo {
+ string long_name = 1;
+ string short_name = 2;
+ string category = 3;
+
+ // Metrics computed by HloCostAnalysis.
+ float flop_count = 4;
+ float transcendental_count = 5;
+ float bytes_accessed = 6;
+ float optimal_seconds = 7;
+
+ // The index into the profile counters array for the HloInstruction
+ // corresponding to this HloInstructionInfo.
+ int64 profile_index = 8;
+ }
+
+ // Pretty-printer information about an HloComputation.
+ message HloComputationInfo {
+ string name = 1;
+
+ // The index into the profile counters array for the HloComputation
+ // corresponding to this HloComputationInfo.
+ int64 profile_index = 2;
+
+ // HloInstructionInfos for every HloInstruction in the HloComputation for
+ // corresponding to this HloComputattionInfo.
+ repeated HloInstructionInfo instruction_infos = 3;
+ }
+
+ // HloComputationInfos for every HloComputation in the HloModule.
+ repeated HloComputationInfo computation_infos = 1;
+
+ // The size of the profile counters array we will pretty-print.
+ int64 profile_counters_size = 2;
+}
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
index 9aa3e501d5..c4876b852e 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/xla.pb.h"
@@ -56,4 +56,4 @@ class HloTfGraphBuilder {
} // namespace hlo_graph_dumper
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 9d9cf0c0f6..6e46f945e0 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -107,8 +107,20 @@ Status ShapeVerifier::HandleInfeed(HloInstruction*) {
return tensorflow::Status::OK();
}
-Status ShapeVerifier::HandleOutfeed(HloInstruction*) {
- return tensorflow::Status::OK();
+Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
+ // Outfeed has a separate shape field for the value which is outfed to the
+ // host. The shape of the instruction itself is always nil because the outfeed
+ // produces no HLO value in the graph.
+ if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
+ outfeed->operand(0)->shape())) {
+ return InvalidArgument(
+ "Expected outfeed to have shape compatible with operand's shape %s, "
+ "actual shape is %s:\n%s",
+ ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
+ ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
+ outfeed->ToString().c_str());
+ }
+ return CheckShape(outfeed, ShapeUtil::MakeNil());
}
Status ShapeVerifier::HandleRng(HloInstruction*) {
@@ -159,7 +171,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
++operand_dimension) {
int64 output_dimension = broadcast->dimensions()[operand_dimension];
TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
- operand_shape.dimensions(operand_dimension));
+ operand_shape.dimensions(operand_dimension))
+ << broadcast->ToString() << " operand shape " << operand_shape;
}
return tensorflow::Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 6368611f32..5a1d864e03 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -127,4 +127,4 @@ class HloVerifier : public HloPassInterface {
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index f80dace877..bbea6bee56 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -81,7 +81,10 @@ OperandLayoutConstraint::OperandLayoutConstraint(
operand_no_(operand_no) {
CHECK(shape_layout_.LayoutIsSet());
CHECK(ShapeUtil::Compatible(shape_layout.shape(),
- instruction->operand(operand_no)->shape()));
+ instruction->operand(operand_no)->shape()))
+ << shape_layout.shape() << " is not compatible with "
+ << instruction->operand(operand_no)->shape() << " (for operand "
+ << operand_no << " of instruction " << instruction->ToString() << ")";
}
string OperandLayoutConstraint::ToString() const {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index 827e092a3f..1c00b2aabd 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
#include <string>
@@ -179,4 +179,4 @@ class KernelSupportLibrary {
};
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index d2bcb38d09..8d1e6338e1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -150,6 +150,8 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
// addition to an addition on this type (int16) - this is just the type
// used for storage.
return llvm::Type::getInt16Ty(module->getContext());
+ case F16:
+ return llvm::Type::getHalfTy(module->getContext());
case S32:
case U32:
return llvm::Type::getInt32Ty(module->getContext());
@@ -292,6 +294,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
ir_element_type,
tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index)));
break;
+ case F16:
+ value = llvm::ConstantFP::get(
+ ir_element_type,
+ static_cast<float>(literal.Get<half>(*multi_index)));
+ break;
case F64:
value = llvm::ConstantFP::get(ir_element_type,
literal.Get<double>(*multi_index));
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h
index f72f482e31..175b081e84 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
@@ -90,4 +90,4 @@ Status EmitParallelFusedDynamicUpdateSliceInPlace(
} // namespace llvm_ir
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 2194d24257..f30530db08 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -128,7 +128,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, argument_layouts, &execution_options));
+ CreateModuleConfig(*program_shape, argument_layouts, &execution_options,
+ *user_computation));
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
execute_backend_->stream_executor(device_ordinal));
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
index 598d08b720..f4c63dd86b 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -90,4 +90,4 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index fc848bdb03..849df1d8e6 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -34,8 +34,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -55,6 +57,7 @@ namespace se = ::perftools::gputools;
using ::tensorflow::strings::Printf;
using ::tensorflow::strings::StrCat;
+using ::xla::source_map_util::InvalidParameterArgument;
namespace xla {
@@ -260,7 +263,8 @@ StatusOr<std::vector<const ShapedBuffer*>> Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
- const ExecutionOptions* execution_options) {
+ const ExecutionOptions* execution_options,
+ const UserComputation& user_computation) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
auto* computation_layout = config->mutable_entry_computation_layout();
@@ -274,8 +278,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
// ProgramShape.
if (!ShapeUtil::Compatible(*argument_shapes[i],
program_shape.parameters(i))) {
- return InvalidArgument(
- "computation expects parameter %d to have shape %s, given shape %s",
+ return InvalidParameterArgument(
+ *user_computation.ParameterMetadata(i).value(),
+ "Argument does not match shape of computation parameter %d: want %s, "
+ "got %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
@@ -317,12 +323,14 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutionOptions& execution_options) {
+ const ExecutionOptions& execution_options,
+ const UserComputation& user_computation) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
argument_shapes.push_back(&arg->on_host_shape());
}
- return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
+ return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
+ user_computation);
}
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
@@ -419,6 +427,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
/*include_unreachable_instructions=*/
true));
+ TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor));
@@ -566,7 +576,7 @@ Service::ExecuteParallelAndRegisterResult(
se::Stream* stream = index_to_profiled_stream.second;
Executable* executable = executables[device];
const HloModule& module = executable->module();
- HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(),
+ HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
TF_RETURN_IF_ERROR(
executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));
@@ -739,9 +749,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
// Create an HloModuleConfig object for the computation, given the shape of
// the program and the argument allocations.
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments,
- request.execution_options()));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(*program_shape, arguments,
+ request.execution_options(), *user_computation));
VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -849,7 +860,8 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments, arg->execution_options()));
+ CreateModuleConfig(*program_shape, arguments, arg->execution_options(),
+ *user_computation));
VLOG(3) << "Execute created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -913,7 +925,8 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments, arg->execution_options()));
+ CreateModuleConfig(*program_shape, arguments, arg->execution_options(),
+ *user_computation));
VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -1233,7 +1246,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(program_shape, {}, execution_options));
+ CreateModuleConfig(program_shape, {}, execution_options,
+ *user_computation));
// Exclude dead parameter instructions for the purpose of computing constants.
TF_ASSIGN_OR_RETURN(
@@ -1597,4 +1611,15 @@ StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Service::Replicas(
return replicas;
}
+Status Service::MaybeDumpHloModule(const HloModule& module) const {
+ const string xla_dump_prepass_hlo_proto_to =
+ module.config().debug_options().xla_dump_prepass_hlo_proto_to();
+ if (xla_dump_prepass_hlo_proto_to.empty()) {
+ return Status::OK();
+ }
+ HloProto proto = MakeHloProto(module);
+ return protobuf_util::DumpProtoToDirectory(
+ proto, xla_dump_prepass_hlo_proto_to, module.name());
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index f962d0cdc7..ca77e8fe3a 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -251,7 +251,8 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutionOptions& execution_options);
+ const ExecutionOptions& execution_options,
+ const UserComputation& user_computation);
protected:
friend class LocalExecutable;
@@ -275,7 +276,8 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
- const ExecutionOptions* execution_options);
+ const ExecutionOptions* execution_options,
+ const UserComputation& user_computation);
// Builds an Executable for the given parameters.
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
@@ -340,6 +342,8 @@ class Service : public ServiceInterface {
StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
const Backend& backend, const DeviceHandle& device_handle) const;
+ Status MaybeDumpHloModule(const HloModule& module) const;
+
// Returns the device handle that represents the replicated device for a
// single computation that is not model-parallelized.
DeviceHandle SingleComputationDeviceHandle() const;
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
new file mode 100644
index 0000000000..8cbaac7b37
--- /dev/null
+++ b/tensorflow/compiler/xla/service/source_map_util.cc
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/source_map_util.h"
+
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace source_map_util {
+namespace {
+
+Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
+ const char* format, va_list args) {
+ string message;
+ tensorflow::strings::Appendv(&message, format, args);
+ if (!op_metadata.source_file().empty()) {
+ tensorflow::strings::Appendf(&message, " (%s:%d)",
+ op_metadata.source_file().c_str(),
+ op_metadata.source_line());
+ }
+ return InvalidArgument("%s", message.c_str());
+}
+
+} // namespace
+
+Status InvalidParameterArgument(const OpMetadata& op_metadata,
+ const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ Status result = InvalidParameterArgumentV(op_metadata, format, args);
+ va_end(args);
+ return result;
+}
+
+Status InvalidParameterArgument(Executable* executable, int parameter_number,
+ const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ if (executable != nullptr && executable->has_module()) {
+ const HloModule& module = executable->module();
+ const HloComputation& computation = *module.entry_computation();
+ HloInstruction* param = computation.parameter_instruction(parameter_number);
+ const OpMetadata& metadata = param->metadata();
+ Status result = InvalidParameterArgumentV(metadata, format, args);
+ va_end(args);
+ return result;
+ }
+ Status result = InvalidArgumentV(format, args);
+ va_end(args);
+ return result;
+}
+
+} // namespace source_map_util
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h
new file mode 100644
index 0000000000..a776d745f4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/source_map_util.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+namespace source_map_util {
+
+// Creates an INVALID_ARUGMENT status with the given format string.
+//
+// Also, attempts to extract the OpMetadata for parameter_number on executable
+// and append it to the status message for source mapping to user code.
+//
+// executable may be nullptr, but parameter_number should not be out of bounds
+// or a CHECK-failure may occur.
+Status InvalidParameterArgument(Executable* executable, int parameter_number,
+ const char* format, ...)
+ TF_PRINTF_ATTRIBUTE(3, 4);
+
+// As above, but takes the parameter metadata directly instead of extracting it
+// from the executable.
+Status InvalidParameterArgument(const OpMetadata& op_metadata,
+ const char* format, ...)
+ TF_PRINTF_ATTRIBUTE(2, 3);
+
+} // namespace source_map_util
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 7882b70ab7..2ea6507900 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -2767,48 +2767,11 @@ HloComputation* ComputationLowerer::ResolveComputation(
HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
HloInstruction* operand, const Shape& output_shape) {
- CHECK(ShapeUtil::IsScalar(operand->shape()) ||
- ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
- Shape broadcast_shape = ShapeUtil::MakeShape(
- operand->shape().element_type(), AsInt64Slice(output_shape.dimensions()));
- // Do explicit broadcast for scalar.
- if (ShapeUtil::IsScalar(operand->shape())) {
- HloInstruction* broadcast = hlo_builder_.AddInstruction(
- HloInstruction::CreateBroadcast(broadcast_shape, operand, {}));
- broadcast->set_metadata(operand->metadata());
- if (operand->has_sharding()) {
- broadcast->set_sharding(operand->sharding());
- }
- return broadcast;
- }
- // Do explicit broadcast for degenerate broadcast.
- std::vector<int64> broadcast_dimensions;
- std::vector<int64> reshaped_dimensions;
- for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
- if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
- broadcast_dimensions.push_back(i);
- reshaped_dimensions.push_back(operand->shape().dimensions(i));
- }
- }
- // Eliminate the size one dimensions.
- HloInstruction* reshaped_operand =
- hlo_builder_.AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(operand->shape().element_type(),
- reshaped_dimensions),
- operand));
- reshaped_operand->set_metadata(operand->metadata());
- if (operand->has_sharding()) {
- reshaped_operand->set_sharding(operand->sharding());
- }
- // Broadcast 'reshape' up to the larger size.
- HloInstruction* broadcast =
- hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, reshaped_operand, broadcast_dimensions));
- broadcast->set_metadata(operand->metadata());
- if (operand->has_sharding()) {
- broadcast->set_sharding(operand->sharding());
- }
- return broadcast;
+ auto fadd = [this](std::unique_ptr<HloInstruction> x) {
+ return hlo_builder_.AddInstruction(std::move(x));
+ };
+ return fadd(
+ HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd));
}
void ComputationLowerer::Visit(
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 50dac32a4a..d3d55634c9 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -41,4 +41,4 @@ class WhileLoopSimplifier : public HloPassInterface {
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index 63afab4206..063e312df6 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -29,4 +29,4 @@ class ZeroSizedHloElimination : public HloPassInterface {
}
};
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_
diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h
index 903fee5255..f2ce22d672 100644
--- a/tensorflow/compiler/xla/sparse_index_array.h
+++ b/tensorflow/compiler/xla/sparse_index_array.h
@@ -15,8 +15,8 @@ limitations under the License.
// Utility class for managing sparse array indices.
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
+#define TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
#include <vector>
@@ -173,4 +173,4 @@ void SparseIndexArray::SortWithValues(
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
+#endif // TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
diff --git a/tensorflow/compiler/xla/statusor_internals.h b/tensorflow/compiler/xla/statusor_internals.h
index a2fda5bb3c..14636bd144 100644
--- a/tensorflow/compiler/xla/statusor_internals.h
+++ b/tensorflow/compiler/xla/statusor_internals.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
+#ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
+#define TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/core/platform/macros.h"
@@ -242,4 +242,4 @@ struct TraitsBase<false, false> {
} // namespace internal_statusor
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
+#endif // TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 3922c779a0..4410647f84 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -351,6 +351,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:platform_util",
@@ -815,9 +816,6 @@ xla_test(
xla_test(
name = "bfloat16_test",
srcs = ["bfloat16_test.cc"],
- blacklisted_backends = [
- "gpu",
- ],
shard_count = 40,
deps = [
":test_utils",
@@ -848,6 +846,31 @@ xla_test(
)
xla_test(
+ name = "half_test",
+ srcs = ["half_test.cc"],
+ backends = [
+ # TODO(b/72509305): Flaky (fails with SEGV) as of 2018-01-25
+ # "cpu",
+ "gpu",
+ ],
+ deps = [
+ ":test_utils",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
name = "slice_test",
srcs = ["slice_test.cc"],
shard_count = 40,
@@ -1013,6 +1036,10 @@ xla_test(
name = "select_and_scatter_test",
timeout = "long",
srcs = ["select_and_scatter_test.cc"],
+ tags = [
+ "enable_for_xla_interpreter",
+ "optonly",
+ ],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
@@ -1052,6 +1079,19 @@ xla_test(
)
xla_test(
+ name = "reduce_hlo_test",
+ srcs = ["reduce_hlo_test.cc"],
+ deps = [
+ ":client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
name = "call_test",
srcs = ["call_test.cc"],
deps = [
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index e47fcad475..b853dfaa15 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -99,8 +99,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
auto expected = Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
- {{{{static_cast<bfloat16>(-1.7f)}, {static_cast<bfloat16>(-2.04f)}},
- {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.65f)}}},
+ {{{{static_cast<bfloat16>(-1.6875f)},
+ {static_cast<bfloat16>(-2.04f)}},
+ {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
.get(),
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index 659660d91e..f594cc10ac 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -104,7 +104,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
ASSERT_FALSE(status.ok());
ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
ASSERT_THAT(status.status().error_message(),
- ContainsRegex("expects parameter 0"));
+ ContainsRegex(
+ "Argument does not match shape of computation parameter 0"));
// Shape mismatch in parameter 1 (rank)
status = client_->Execute(computation, {f32_data.get(), f32_data.get()},
@@ -112,7 +113,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
ASSERT_FALSE(status.ok());
ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
ASSERT_THAT(status.status().error_message(),
- ContainsRegex("expects parameter 1"));
+ ContainsRegex(
+ "Argument does not match shape of computation parameter 1"));
// Shape mismatch in parameter 1 (element type)
status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()},
@@ -120,7 +122,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
ASSERT_FALSE(status.ok());
ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
ASSERT_THAT(status.status().error_message(),
- ContainsRegex("expects parameter 1"));
+ ContainsRegex(
+ "Argument does not match shape of computation parameter 1"));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 7c9494f133..a677986cd9 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -387,7 +387,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqualTuple(expected, *actual);
+ LiteralTestUtil::ExpectEqual(expected, *actual);
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -399,7 +399,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- LiteralTestUtil::ExpectNearTuple(expected, *actual, error);
+ LiteralTestUtil::ExpectNear(expected, *actual, error);
}
void ClientLibraryTestBase::ComputeAndCompare(
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index a559a653df..ba0319990b 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -431,6 +431,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
@@ -456,6 +457,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
@@ -481,6 +483,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
@@ -506,6 +509,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
@@ -531,6 +535,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 0016b6cc61..bc82167482 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -355,8 +355,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
}
// Test true and false computations that return a tuple of arrays.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) {
+XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
ComputationBuilder builder(client_, TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}),
@@ -373,9 +372,7 @@ XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) {
// Test true and false computations that return a tuple of a predicate, a
// scalar, and an array.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest,
- DISABLED_ON_GPU(ReturnTupleofPredicateScalarArray)) {
+XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
ComputationBuilder true_builder(client_, TestName() + ".true");
{
true_builder.Parameter(0, empty_tuple_, "tuple");
@@ -413,8 +410,7 @@ XLA_TEST_F(ConditionalOpTest,
}
// Test true and false computations that return a nested tuple.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnNestedTuple)) {
+XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputationBuilder true_builder(client_, TestName() + ".true");
{
true_builder.Parameter(0, empty_tuple_, "tuple");
@@ -532,6 +528,32 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
+XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
+ ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional");
+ {
+ Shape r0bool = ShapeUtil::MakeShape(PRED, {});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
+ auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
+ auto pred_cond = inner_builder.GetTupleElement(param0, 0);
+ auto true_operand = inner_builder.GetTupleElement(param0, 1);
+ auto false_operand = inner_builder.GetTupleElement(param0, 2);
+ inner_builder.Conditional(pred_cond, true_operand,
+ CreateR0CeilComputation(), false_operand,
+ CreateR0FloorComputation());
+ }
+ auto inner_builder_result = inner_builder.Build();
+ EXPECT_IS_OK(inner_builder_result.status());
+
+ ComputationBuilder builder(client_, TestName());
+ auto pred2 = builder.ConstantR0<bool>(false);
+ auto operand1 = builder.ConstantR0<float>(1.1f);
+ auto operand2 = builder.ConstantR0<float>(12.2f);
+ auto tuple_operand = builder.Tuple({pred2, operand1, operand2});
+ builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand});
+
+ ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+}
+
// Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index a10e17dbf3..0ceb9aff37 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -608,5 +608,28 @@ INSTANTIATE_TEST_CASE_P(
);
+TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
+ ComputationBuilder builder(client_, TestName());
+ Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
+ Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<bfloat16> input_data(1, 1, 1, 2);
+ input_data.FillWithYX(Array2D<bfloat16>({
+ {bfloat16(1), bfloat16(2)},
+ }));
+ Array4D<bfloat16> filter_data(1, 1, 1, 2);
+ filter_data.FillWithYX(Array2D<bfloat16>({
+ {bfloat16(5), bfloat16(6)},
+ }));
+
+ ComputeAndCompare(&builder, conv,
+ {std::move(*Literal::CreateFromArray(input_data)),
+ std::move(*Literal::CreateFromArray(filter_data))},
+ error_spec_);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index ae3f887240..877dc7db0e 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -595,6 +595,11 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
// Single element, no wrap.
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) {
+ // Single element, no wrap.
+ std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
}
@@ -602,6 +607,11 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
// Multiple element, no wrap.
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/2);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) {
+ // Multiple element, no wrap.
+ std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/2);
}
@@ -609,6 +619,11 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) {
// Multiple element, wrapping.
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/3, /*size=*/2);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) {
+ // Multiple element, wrapping.
+ std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/3, /*size=*/2);
}
@@ -616,12 +631,21 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) {
// Multiple element, update size larger than operand.
std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<float>(operand_shape, /*index=*/5, /*size=*/2);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLargeBF16) {
+ // Multiple element, update size larger than operand.
+ std::vector<int32> operand_shape({4, 5, 2});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/5, /*size=*/2);
}
XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
std::vector<int32> operand_shape({3, 123, 247});
RunR3Contiguous<float>(operand_shape, /*index=*/1, /*size=*/1);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnalignedBF16) {
+ std::vector<int32> operand_shape({3, 123, 247});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/1, /*size=*/1);
}
@@ -629,6 +653,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) {
std::vector<int32> operand_shape({32, 128, 1024});
RunR3Contiguous<float>(operand_shape, /*index=*/7, /*size=*/1);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLargerBF16)) {
+ std::vector<int32> operand_shape({32, 128, 1024});
RunR3Contiguous<bfloat16>(operand_shape, /*index=*/7, /*size=*/1);
}
diff --git a/tensorflow/compiler/xla/tests/filecheck.h b/tensorflow/compiler/xla/tests/filecheck.h
index 493ff7414b..3830d5a44d 100644
--- a/tensorflow/compiler/xla/tests/filecheck.h
+++ b/tensorflow/compiler/xla/tests/filecheck.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
#include <string>
@@ -30,4 +30,4 @@ StatusOr<bool> RunFileCheck(const string& input, const string& pattern);
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
new file mode 100644
index 0000000000..ec2f49d43b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -0,0 +1,257 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+
+// Tests the handling of the basic mathematics operations with F16 operands.
+
+namespace xla {
+namespace {
+
+class HalfTestBase : public ClientLibraryTestBase {
+ protected:
+ const ErrorSpec error_spec_{0.001, 0.001};
+ // Number of elements in the input buffers.
+ static const int kNumElements = 4;
+};
+
+using UnaryBuildFuncTy =
+ std::function<void(ComputationBuilder*, const ComputationDataHandle& src)>;
+
+struct UnaryOpTestParam {
+ std::function<half(half)> compute_func;
+ UnaryBuildFuncTy build_func;
+};
+
+class UnaryOpTest : public HalfTestBase,
+ public ::testing::WithParamInterface<UnaryOpTestParam> {};
+
+XLA_TEST_P(UnaryOpTest, Ops) {
+ std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1)});
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle x_opnd;
+ auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
+ &builder, &x_opnd);
+
+ std::function<half(half)> compute_func = GetParam().compute_func;
+ std::vector<half> expected;
+ for (int64 i = 0; i < x.size(); ++i) {
+ expected.push_back(compute_func(x[i]));
+ }
+
+ UnaryBuildFuncTy build_func = GetParam().build_func;
+ build_func(&builder, x_opnd);
+
+ ComputeAndCompareR1<half>(&builder, expected, {x_data.get()}, error_spec_);
+}
+
+half sign_imp(half value) {
+ const float x(std::move(value));
+ return half((x < .0) ? -1 : (x > .0));
+}
+
+half round_imp(half value) {
+ return half(round(static_cast<float>(std::move(value))));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ half, UnaryOpTest,
+ ::testing::Values(UnaryOpTestParam{[](half x) { return abs(x); },
+ &ComputationBuilder::Abs},
+ UnaryOpTestParam{[](half x) { return round_imp(x); },
+ &ComputationBuilder::Round},
+ UnaryOpTestParam{[](half x) { return ceil(x); },
+ &ComputationBuilder::Ceil},
+ UnaryOpTestParam{[](half x) { return cos(x); },
+ &ComputationBuilder::Cos},
+ UnaryOpTestParam{[](half x) { return exp(x); },
+ &ComputationBuilder::Exp},
+ UnaryOpTestParam{[](half x) { return floor(x); },
+ &ComputationBuilder::Floor},
+ UnaryOpTestParam{[](half x) { return log(x); },
+ &ComputationBuilder::Log},
+ UnaryOpTestParam{[](half x) { return -x; },
+ &ComputationBuilder::Neg},
+ UnaryOpTestParam{[](half x) { return sign_imp(x); },
+ &ComputationBuilder::Sign},
+ UnaryOpTestParam{[](half x) { return sin(x); },
+ &ComputationBuilder::Sin},
+ UnaryOpTestParam{[](half x) { return tanh(x); },
+ &ComputationBuilder::Tanh}
+
+ ));
+
+struct UnaryPredTestParam {
+ std::function<bool(half)> compute_func;
+ UnaryBuildFuncTy build_func;
+};
+
+class UnaryPredTest : public HalfTestBase,
+ public ::testing::WithParamInterface<UnaryPredTestParam> {
+};
+
+XLA_TEST_P(UnaryPredTest, Ops) {
+ std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1)});
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle x_opnd;
+ auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
+ &builder, &x_opnd);
+
+ std::function<bool(half)> compute_func = GetParam().compute_func;
+ CHECK_EQ(kNumElements, x.size());
+ bool expected[kNumElements];
+ for (int64 i = 0; i < x.size(); ++i) {
+ expected[i] = compute_func(x[i]);
+ }
+
+ UnaryBuildFuncTy build_func = GetParam().build_func;
+ build_func(&builder, x_opnd);
+
+ ComputeAndCompareR1<bool>(&builder, expected, {x_data.get()});
+}
+
+INSTANTIATE_TEST_CASE_P(half, UnaryPredTest,
+ ::testing::Values(UnaryPredTestParam{
+ [](half x) { return isfinite(x); },
+ &ComputationBuilder::IsFinite}));
+
+using BinaryBuildFuncTy = std::function<void(
+ ComputationBuilder*, const ComputationDataHandle& x,
+ const ComputationDataHandle& y, tensorflow::gtl::ArraySlice<int64>)>;
+
+struct BinaryOpTestParam {
+ std::function<half(half, half)> compute_func;
+ BinaryBuildFuncTy build_func;
+};
+
+class BinaryOpTest : public HalfTestBase,
+ public ::testing::WithParamInterface<BinaryOpTestParam> {};
+
+XLA_TEST_P(BinaryOpTest, Ops) {
+ std::vector<half> x({half(1.0), half(2.0), half(3.0), half(-4.0)});
+ std::vector<half> y({half(0.4), half(-0.3), half(0.2), half(0.1)});
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle x_opnd;
+ auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
+ &builder, &x_opnd);
+
+ ComputationDataHandle y_opnd;
+ auto y_data = CreateR1Parameter<half>(y, /*parameter_number=*/1, "y",
+ &builder, &y_opnd);
+
+ std::function<half(half, half)> compute_func = GetParam().compute_func;
+ std::vector<half> expected;
+ for (int64 i = 0; i < x.size(); ++i) {
+ expected.push_back(compute_func(x[i], y[i]));
+ }
+
+ BinaryBuildFuncTy build_func = GetParam().build_func;
+ build_func(&builder, x_opnd, y_opnd, {});
+
+ ComputeAndCompareR1<half>(&builder, expected, {x_data.get(), y_data.get()},
+ error_spec_);
+}
+
+half atan2_imp(half x, half y) {
+ return half(atan2(static_cast<float>(std::move(x)),
+ static_cast<float>(std::move(y))));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ half, BinaryOpTest,
+ ::testing::Values(
+ BinaryOpTestParam{[](half x, half y) { return x + y; },
+ &ComputationBuilder::Add},
+ BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); },
+ &ComputationBuilder::Atan2},
+ BinaryOpTestParam{[](half x, half y) { return x / y; },
+ &ComputationBuilder::Div},
+ BinaryOpTestParam{[](half x, half y) { return max(x, y); },
+ &ComputationBuilder::Max},
+ BinaryOpTestParam{[](half x, half y) { return min(x, y); },
+ &ComputationBuilder::Min},
+ BinaryOpTestParam{[](half x, half y) { return x * y; },
+ &ComputationBuilder::Mul},
+ BinaryOpTestParam{[](half x, half y) { return pow(x, y); },
+ &ComputationBuilder::Pow},
+ BinaryOpTestParam{[](half x, half y) { return x - y; },
+ &ComputationBuilder::Sub}
+
+ ));
+
+struct BinaryPredTestParam {
+ std::function<bool(half, half)> compute_func;
+ BinaryBuildFuncTy build_func;
+};
+
+class BinaryPredTest
+ : public HalfTestBase,
+ public ::testing::WithParamInterface<BinaryPredTestParam> {};
+
+XLA_TEST_P(BinaryPredTest, Ops) {
+ std::vector<half> x({half(1.0), half(2.0), half(0.2), half(-4.0)});
+ std::vector<half> y({half(0.4), half(-0.3), half(0.2), half(0.1)});
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle x_opnd;
+ auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
+ &builder, &x_opnd);
+
+ ComputationDataHandle y_opnd;
+ auto y_data = CreateR1Parameter<half>(y, /*parameter_number=*/1, "y",
+ &builder, &y_opnd);
+
+ std::function<bool(half, half)> compute_func = GetParam().compute_func;
+ CHECK_EQ(kNumElements, x.size());
+ bool expected[kNumElements];
+ for (int64 i = 0; i < x.size(); ++i) {
+ expected[i] = compute_func(x[i], y[i]);
+ }
+
+ BinaryBuildFuncTy build_func = GetParam().build_func;
+ build_func(&builder, x_opnd, y_opnd, {});
+
+ ComputeAndCompareR1<bool>(&builder, expected, {x_data.get(), y_data.get()});
+}
+
+INSTANTIATE_TEST_CASE_P(
+ half, BinaryPredTest,
+ ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; },
+ &ComputationBuilder::Eq},
+ BinaryPredTestParam{[](half x, half y) { return x != y; },
+ &ComputationBuilder::Ne},
+ BinaryPredTestParam{[](half x, half y) { return x >= y; },
+ &ComputationBuilder::Ge},
+ BinaryPredTestParam{[](half x, half y) { return x > y; },
+ &ComputationBuilder::Gt},
+ BinaryPredTestParam{[](half x, half y) { return x <= y; },
+ &ComputationBuilder::Le},
+ BinaryPredTestParam{[](half x, half y) { return x < y; },
+ &ComputationBuilder::Lt}
+
+ ));
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index e5b96c51ce..f8205de702 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -301,6 +301,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
case BF16:
match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
break;
+ case F16:
+ match = ExpectLiteralsEqual<half>(expected, actual, &multi_index, 0);
+ break;
case F32:
match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
break;
@@ -313,6 +316,10 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
case TUPLE: {
bool tuple_match = true;
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ SCOPED_TRACE(tensorflow::strings::StrCat(
+ "Tuple index ", i, " in ",
+ ShapeUtil::HumanString(expected.shape())));
+
// Create LiteralViews of the expected and actual elements.
auto result = Equal(LiteralView::Create(expected, {i}),
LiteralView::Create(actual, {i}));
@@ -336,47 +343,6 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
return result;
}
-/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple(
- const Literal& expected, const Literal& actual) {
- VLOG(1) << "expected: " << expected.ToString();
- VLOG(1) << "actual: " << actual.ToString();
-
- if (!ShapeUtil::IsTuple(expected.shape()) ||
- !ShapeUtil::IsTuple(actual.shape())) {
- return ::testing::AssertionFailure()
- << "tuples expected shape = " << expected.shape().ShortDebugString()
- << " actual shape = " << actual.shape().ShortDebugString();
- }
- AssertEqualShapes(expected.shape(), actual.shape());
-
- ::testing::AssertionResult err = ::testing::AssertionSuccess();
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
- "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape())));
- const auto expected_element = LiteralView::Create(expected, {i});
- const auto actual_element = LiteralView::Create(actual, {i});
-
- ::testing::AssertionResult res = [&] {
- if (ShapeUtil::IsTuple(expected_element.shape())) {
- return EqualTuple(expected_element, actual_element);
- } else {
- return Equal(expected_element, actual_element);
- }
- }();
-
- if (!res && err) {
- err = res;
- }
- }
-
- return err;
-}
-
-/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
- const Literal& actual) {
- EXPECT_TRUE(EqualTuple(expected, actual));
-}
-
namespace {
// Helper class for comparing floating-point literals within an error bound.
@@ -417,6 +383,9 @@ class NearComparator {
case BF16:
ExpectLiteralsNear<bfloat16>(expected, actual, 0);
break;
+ case F16:
+ ExpectLiteralsNear<half>(expected, actual, 0);
+ break;
case F32:
ExpectLiteralsNear<float>(expected, actual, 0);
break;
@@ -609,14 +578,47 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
static_cast<float>(actual));
}
+template <>
+bool NearComparator::ExpectValuesNear<half>(half expected, half actual) {
+ return ExpectValuesNear(static_cast<float>(std::move(expected)),
+ static_cast<float>(std::move(actual)));
+}
+
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
const Literal& expected, const Literal& actual, const ErrorSpec& error) {
- NearComparator comparator(error);
- return comparator.ExpectNear(expected, actual)
- ? ::testing::AssertionSuccess()
- : ::testing::AssertionFailure() << "values were not near";
+ ::testing::AssertionResult err =
+ EqualShapes(expected.shape(), actual.shape());
+ if (!err) {
+ return err;
+ }
+
+ if (ShapeUtil::IsTuple(expected.shape())) {
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ SCOPED_TRACE(tensorflow::strings::StrCat(
+ "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape())));
+ const auto expected_element = LiteralView::Create(expected, {i});
+ const auto actual_element = LiteralView::Create(actual, {i});
+
+ ::testing::AssertionResult res =
+ Near(expected_element, actual_element, error);
+ if (err && !res) {
+ err = res;
+ }
+ }
+ return err;
+ }
+
+ if (ShapeUtil::ElementIsFloating(expected.shape()) ||
+ ShapeUtil::ElementIsComplex(expected.shape())) {
+ NearComparator comparator(error);
+ return comparator.ExpectNear(expected, actual)
+ ? ::testing::AssertionSuccess()
+ : ::testing::AssertionFailure() << "values were not near";
+ }
+
+ return Equal(expected, actual);
}
/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
@@ -629,65 +631,13 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
: tensorflow::strings::StrCat("\nmessage: ", message));
}
-/* static */ ::testing::AssertionResult LiteralTestUtil::NearTuple(
- const Literal& expected, const Literal& actual, const ErrorSpec& error) {
- VLOG(1) << "expected: " << expected.ToString();
- VLOG(1) << "actual: " << actual.ToString();
-
- if (!ShapeUtil::IsTuple(expected.shape()) ||
- !ShapeUtil::IsTuple(actual.shape())) {
- return ::testing::AssertionFailure()
- << "tuples expected shape = " << expected.shape().ShortDebugString()
- << " actual shape = " << actual.shape().ShortDebugString();
- }
- AssertEqualShapes(expected.shape(), actual.shape());
-
- ::testing::AssertionResult err = ::testing::AssertionSuccess();
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- SCOPED_TRACE(tensorflow::strings::StrCat(
- "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape())));
- const auto expected_element = LiteralView::Create(expected, {i});
- const auto actual_element = LiteralView::Create(actual, {i});
-
- ::testing::AssertionResult res = [&] {
- if (ShapeUtil::IsTuple(expected_element.shape())) {
- return NearTuple(expected_element, actual_element, error);
- } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) {
- return Near(expected_element, actual_element, error);
- } else {
- return Equal(expected_element, actual_element);
- }
- }();
-
- if (err && !res) {
- err = res;
- }
- }
- return err;
-}
-
-/* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected,
- const Literal& actual,
- const ErrorSpec& error) {
- EXPECT_TRUE(NearTuple(expected, actual, error));
-}
-
/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
const Literal& expected, const Literal& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
- bool is_tuple = ShapeUtil::IsTuple(expected.shape());
if (error.has_value()) {
- if (is_tuple) {
- VLOG(1) << "Expects near tuple";
- return NearTuple(expected, actual, *error);
- }
VLOG(1) << "Expects near";
return Near(expected, actual, *error);
}
- if (is_tuple) {
- VLOG(1) << "Expects equal tuple";
- return EqualTuple(expected, actual);
- }
VLOG(1) << "Expects equal";
return Equal(expected, actual);
}
@@ -712,6 +662,7 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
new_num_elements *= new_dimensions[i];
}
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
+ CHECK_EQ(new_dimensions.size(), minor_to_major.size());
auto new_literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
@@ -761,6 +712,10 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
new_literal->Set<double>(to_multi_index,
literal.Get<double>(from_multi_index));
break;
+ case C64:
+ new_literal->Set<complex64>(to_multi_index,
+ literal.Get<complex64>(from_multi_index));
+ break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
<< PrimitiveType_Name(literal.shape().element_type());
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index f53553c701..9b0724262d 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -111,17 +111,18 @@ class LiteralTestUtil {
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
const Literal& actual);
- // Returns whether the two tuples are equal.
- static ::testing::AssertionResult EqualTuple(
- const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
-
- // Expects that the values of the elements in the expected and actual tuples
- // are equal. Tuples are matched recursively.
- static void ExpectEqualTuple(const Literal& expected, const Literal& actual);
-
// Asserts that the expected and actual literals are within the given error
// bound for all elements. Also, asserts that the rank, dimensions sizes, and
- // bounds are equivalent. Only supported for floating point values.
+ // bounds are equivalent.
+ //
+ // Tuples are matched recursively. When comparing tensors of
+ // non-floating-point type, checks for exact equality, ignoring the ErroSpec.
+ //
+ // If the shape of the literals is neither a complex/floating-point tensor nor
+ // a tuple which contains a complex/floating-point tensor, Near() is
+ // equivalent to Equal(). We don't raise an error in this case, because we
+ // want to allow callers to call Near() even if they have no preconceptions
+ // about the shapes being compared.
static ::testing::AssertionResult Near(
const Literal& expected, const Literal& actual,
const ErrorSpec& error) TF_MUST_USE_RESULT;
@@ -170,18 +171,6 @@ class LiteralTestUtil {
const Literal& actual,
const ErrorSpec& error);
- // Returns whether the values of the elements in the expected and actual
- // tuples are within the given error bound. Tuples are matched recursively.
- // If the elements of the tuple are not floating-point types, the error spec
- // is ignored and exact equality is checked.
- static ::testing::AssertionResult NearTuple(
- const Literal& expected, const Literal& actual,
- const ErrorSpec& error) TF_MUST_USE_RESULT;
-
- // Expects that the expected and actual values are near.
- static void ExpectNearTuple(const Literal& expected, const Literal& actual,
- const ErrorSpec& error);
-
// If the error spec is given, returns whether the expected and the actual are
// within the error bound; otherwise, returns whether they are equal. Tuples
// will be compared recursively.
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
index 569d5944ca..47cab79604 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -44,8 +44,7 @@ TEST_F(LocalClientAotTest, Constant) {
OpaqueData opaque_data{100, 20, 3};
void* parameters[] = {&opaque_data};
float out = 0;
- char tmp[4] = {0};
- void* temporary_buffers[] = {nullptr, &out, &tmp};
+ void* temporary_buffers[] = {nullptr, &out};
SumAndDouble(&out, &run_options, parameters, temporary_buffers);
EXPECT_EQ(out, 246.0f);
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 4d3b513b09..3704ddd801 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -87,10 +87,9 @@ int main(int argc, char** argv) {
// It's lame to hard-code the buffer assignments, but we need
// local_client_aot_test.cc to be able to easily invoke the function.
CHECK_EQ(result->result_buffer_index(), 1);
- CHECK_EQ(result->buffer_sizes().size(), 3);
+ CHECK_EQ(result->buffer_sizes().size(), 2);
CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer
CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer
- CHECK_EQ(result->buffer_sizes()[2], sizeof(float)); // temp buffer
if (triple.isOSBinFormatELF()) {
// Check the ELF magic.
CHECK_EQ(result->object_file_data()[0], 0x7F);
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index 0fb87c3c2c..6c86dd5b9e 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -221,5 +221,77 @@ INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest,
::testing::Combine(::testing::Bool(), ::testing::Bool(),
::testing::Bool()));
+class MatOpsDotAddTest_bf16
+ : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<std::tuple<bool, bool, bool>> {};
+
+TEST_P(MatOpsDotAddTest_bf16, Dot_Add_2x2_2x2) {
+ bool row_major = std::get<0>(GetParam());
+ bool add_lhs = std::get<1>(GetParam());
+ bool transpose = std::get<2>(GetParam());
+ Array2D<bfloat16> lhs(
+ {{bfloat16(1.0f), bfloat16(2.0f)}, {bfloat16(3.0), bfloat16(4.0)}});
+ Array2D<bfloat16> rhs(
+ {{bfloat16(10.0f), bfloat16(11.0f)}, {bfloat16(12.0f), bfloat16(13.0f)}});
+
+ auto minor_to_major = [](bool row_major) -> std::vector<int64> {
+ return {row_major ? 1 : 0, row_major ? 0 : 1};
+ };
+
+ auto prim_type = primitive_util::NativeToPrimitiveType<bfloat16>();
+ Shape lhs_shape =
+ ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()});
+ Shape rhs_shape =
+ ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()});
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto lhs_handle,
+ client_->TransferToServer(
+ *Literal::CreateR2FromArray2DWithLayout<bfloat16>(
+ lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto rhs_handle,
+ client_->TransferToServer(
+ *Literal::CreateR2FromArray2DWithLayout<bfloat16>(
+ rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs");
+ auto lhs_mat_arg = lhs_arg;
+ if (transpose) {
+ lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0});
+ }
+ auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs");
+ auto result = builder.Dot(lhs_mat_arg, rhs_arg);
+ Array2D<bfloat16> expected;
+ if (add_lhs) {
+ result = builder.Add(result, lhs_arg);
+ if (transpose) {
+ expected = Array2D<bfloat16>(
+ {{bfloat16(47), bfloat16(52)}, {bfloat16(71), bfloat16(78)}});
+ } else {
+ expected = Array2D<bfloat16>(
+ {{bfloat16(35), bfloat16(39)}, {bfloat16(81), bfloat16(89)}});
+ }
+ } else {
+ result = builder.Add(result, rhs_arg);
+ if (transpose) {
+ expected = Array2D<bfloat16>(
+ {{bfloat16(56), bfloat16(61)}, {bfloat16(80), bfloat16(87)}});
+ } else {
+ expected = Array2D<bfloat16>(
+ {{bfloat16(44), bfloat16(48)}, {bfloat16(90), bfloat16(98)}});
+ }
+ }
+
+ ComputeAndCompareR2<bfloat16>(&builder, expected,
+ {lhs_handle.get(), rhs_handle.get()},
+ ErrorSpec(1e-6));
+}
+
+INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest_bf16,
+ ::testing::Combine(::testing::Bool(), ::testing::Bool(),
+ ::testing::Bool()));
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 6489eee9f3..6aafb9fa6c 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <limits>
#include <memory>
#include "tensorflow/compiler/xla/client/computation_builder.h"
@@ -36,36 +37,42 @@ namespace {
class PrngTest : public ClientLibraryTestBase {
protected:
template <typename T>
- void UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims);
-
- template <typename T>
- void BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims);
+ std::unique_ptr<Literal> UniformTest(T a, T b,
+ tensorflow::gtl::ArraySlice<int64> dims,
+ int64 seed = 42);
// Computes the χ² statistic of a sample of the discrete uniform distribution
// of the given range size. `expected_count` is the number of times each
// possible value is expected to be generated. Thus, the sample size is
// `range_size * expected_count`.
- double UniformChiSquared(int32 range_size, int32 expected_count);
+ double UniformChiSquared(int32 range_size, int32 expected_count,
+ int64 seed = 42);
};
template <typename T>
-void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims) {
+std::unique_ptr<Literal> PrngTest::UniformTest(
+ T a, T b, tensorflow::gtl::ArraySlice<int64> dims, int64 seed) {
ComputationBuilder builder(client_, TestName());
builder.RngUniform(
builder.ConstantR0<T>(a), builder.ConstantR0<T>(b),
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims));
- SetSeed(42);
+ SetSeed(seed);
auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
actual->EachCell<T>([=](tensorflow::gtl::ArraySlice<int64>, T value) {
EXPECT_LE(a, value);
EXPECT_LT(value, b);
});
+ return actual;
}
// Uniform random number generation tests
XLA_TEST_F(PrngTest, ScalarU01) { UniformTest<float>(0, 1, {}); }
+XLA_TEST_F(PrngTest, ScalarU01limits) {
+ UniformTest<float>(std::numeric_limits<float>::min(),
+ std::numeric_limits<float>::max(), {});
+}
XLA_TEST_F(PrngTest, ZeroValuesU01) { UniformTest<float>(0, 1, {0}); }
XLA_TEST_F(PrngTest, TenValuesU01) { UniformTest<float>(0, 1, {10}); }
XLA_TEST_F(PrngTest, TenValuesU37) { UniformTest<float>(3, 7, {10}); }
@@ -73,6 +80,56 @@ XLA_TEST_F(PrngTest, ZeroValuesR2) { UniformTest<float>(0, 1, {0, 20}); }
XLA_TEST_F(PrngTest, LargeU01) { UniformTest<float>(0, 1, {0x100, 0x100}); }
XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
+// TODO(b/71543667): Fix Rng ops on LLVM backends.
+XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
+ DISABLED_ON_CPU(ScalarBF16Tests)))) {
+ for (int64 seed = 0; seed < 100; ++seed) {
+ // The largest negative number smaller than zero in bf16 that's not
+ // denormalized.
+ int32 low_raw = 0x80800000;
+ const float low = reinterpret_cast<const float&>(low_raw);
+ float high = 0.0f;
+ UniformTest<bfloat16>(static_cast<bfloat16>(low),
+ static_cast<bfloat16>(high), {}, /*seed=*/seed);
+
+ // Test odd and even values.
+ UniformTest<bfloat16>(static_cast<bfloat16>(32.75),
+ static_cast<bfloat16>(33), {}, /*seed=*/seed);
+ UniformTest<bfloat16>(static_cast<bfloat16>(32.50),
+ static_cast<bfloat16>(32.75), {}, /*seed=*/seed);
+ UniformTest<bfloat16>(static_cast<bfloat16>(-33.00),
+ static_cast<bfloat16>(-32.75), {}, /*seed=*/seed);
+ UniformTest<bfloat16>(static_cast<bfloat16>(-32.75),
+ static_cast<bfloat16>(-32.50), {}, /*seed=*/seed);
+ }
+}
+
+// TODO(b/71543667): Fix Rng ops on LLVM backends.
+XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(
+ DISABLED_ON_CPU_PARALLEL(ScalarBF16CountTests)))) {
+ // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75,
+ // they should get similar counts.
+ bfloat16 low = static_cast<bfloat16>(32.25);
+ bfloat16 high = static_cast<bfloat16>(33);
+ bfloat16 interval = static_cast<bfloat16>(0.25);
+ std::vector<int32> counts(static_cast<int64>((high - low) / interval), 0);
+
+ constexpr int64 count = 100;
+ for (int64 seed = 0; seed < count; ++seed) {
+ auto result = UniformTest<bfloat16>(low, high, {}, /*seed=*/seed);
+ result->Literal::EachCell<bfloat16>(
+ [&](tensorflow::gtl::ArraySlice<int64>, bfloat16 value) {
+ int64 index = static_cast<int64>((value - low) / interval);
+ counts[index]++;
+ });
+ }
+ // Each bucket should have similar amount of counts. That is, not more than
+ // 10% of total counts. This mostly tests that we don't fall into a 1:2:2
+ // distribution, which yields 20% expected difference.
+ EXPECT_LT(std::abs(counts[0] - counts[1]), count * 0.1);
+ EXPECT_LT(std::abs(counts[1] - counts[2]), count * 0.1);
+}
+
namespace {
template <typename T>
T Square(T x) {
@@ -80,7 +137,8 @@ T Square(T x) {
}
} // namespace
-double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) {
+double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count,
+ int64 seed) {
int32 sample_size = range_size * expected_count;
ComputationBuilder builder(client_, TestName());
@@ -88,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) {
builder.ConstantR0<int32>(range_size),
ShapeUtil::MakeShape(S32, {sample_size}));
- SetSeed(42);
+ SetSeed(seed);
auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
std::vector<int32> counts(range_size, 0);
actual->EachCell<int32>([&counts](tensorflow::gtl::ArraySlice<int64>,
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
new file mode 100644
index 0000000000..c0a2c0ca4c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -0,0 +1,132 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <array>
+
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+// Tests the Reduce HLO in ways that can't be done using the ComputationBuilder
+// API.
+
+namespace xla {
+namespace {
+
+namespace str_util = tensorflow::str_util;
+namespace strings = tensorflow::strings;
+
+struct ReduceLayout {
+ std::array<int64, 4> input_minor_to_major;
+ std::array<int64, 3> output_minor_to_major;
+
+ string ToString() const {
+ return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_",
+ str_util::Join(output_minor_to_major, "x"));
+ }
+};
+
+string PrintReduceLayout(
+ ::testing::TestParamInfo<ReduceLayout> reduce_layout_param) {
+ return reduce_layout_param.param.ToString();
+}
+
+void PrintTo(const ReduceLayout& reduce_layout, ::std::ostream* os) {
+ *os << reduce_layout.ToString();
+}
+
+class ReduceWithLayoutTest
+ : public HloTestBase,
+ public ::testing::WithParamInterface<ReduceLayout> {};
+
+StatusOr<std::unique_ptr<HloModule>> GetParsedModule() {
+ const char* const hlo_string = R"(
+HloModule BadReduce
+
+Sum {
+ x.1 = f32[] parameter(0)
+ y.1 = f32[] parameter(1)
+ ROOT add.1 = f32[] add(x.1, y.1)
+}
+
+ENTRY reduce.1 {
+ parameter = f32[2,2,2,3]{3,2,1,0} parameter(0)
+ init_value = f32[] constant(0)
+ reduce = f32[2,2,3]{2,1,0} reduce(parameter, init_value), dimensions={1}, to_apply=Sum
+ ROOT copy = f32[2,2,3]{2,1,0} copy(reduce)
+}
+)";
+
+ return tools::Parse(hlo_string);
+}
+
+// TODO(b/72454718): XLA:GPU does not support executing code compiled without
+// optimizations.
+XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) {
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, GetParsedModule());
+ HloInstruction* reduce_instruction =
+ module->entry_computation()->root_instruction()->mutable_operand(0);
+ ASSERT_EQ(reduce_instruction->opcode(), HloOpcode::kReduce);
+
+ const ReduceLayout& reduce_layout = GetParam();
+
+ Shape* reduce_output_shape = reduce_instruction->mutable_shape();
+ *reduce_output_shape->mutable_layout() =
+ LayoutUtil::MakeLayout(reduce_layout.output_minor_to_major);
+
+ Shape* reduce_input_shape =
+ reduce_instruction->mutable_operand(0)->mutable_shape();
+ *reduce_input_shape->mutable_layout() =
+ LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major);
+
+ std::unique_ptr<Literal> reduce_input =
+ Literal::CreateR4<float>({{ /*i0=0*/
+ {/*i1=0*/
+ {-0.246092796, -0.179497838, -0.161181688},
+ {-0.151643038, -0.240213156, -0.198156}},
+ {/*i1=1*/
+ {-0.14222312, -0.162200093, -0.193907976},
+ {-0.239411, -0.198166847, -0.172471642}}},
+ { /*i0=1*/
+ {/*i1=0*/
+ {-0.22965157, -0.218723893, -0.129257083},
+ {-0.188762426, -0.16123569, -0.181166649}},
+ {/*i1=1*/
+ {-0.241772294, -0.245131493, -0.160247207},
+ {-0.179881215, -0.23383224, -0.121976733}}}});
+
+ EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
+}
+
+INSTANTIATE_TEST_CASE_P(ReduceWithLayoutTest_Instantiation,
+ ReduceWithLayoutTest,
+ ::testing::Values( //
+ ReduceLayout{{3, 2, 1, 0}, {0, 1, 2}}, //
+ ReduceLayout{{3, 2, 1, 0}, {0, 2, 1}}, //
+ ReduceLayout{{3, 2, 1, 0}, {1, 2, 0}}, //
+ ReduceLayout{{3, 2, 1, 0}, {1, 0, 2}}, //
+ ReduceLayout{{3, 2, 1, 0}, {2, 0, 1}}, //
+ ReduceLayout{{3, 2, 1, 0}, {2, 1, 0}}, //
+ ReduceLayout{{3, 1, 2, 0}, {1, 2, 0}}, //
+ ReduceLayout{{1, 2, 3, 0}, {1, 0, 2}}, //
+ ReduceLayout{{0, 2, 1, 3}, {2, 0, 1}}), //
+ PrintReduceLayout);
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 01f23efcd5..7f3c72671d 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -533,6 +533,7 @@ struct R4ReduceWindowTestData {
int64 strides[4];
int64 pad_low[4];
int64 pad_high[4];
+ int64 layout[4];
Reducer reducer;
};
@@ -548,7 +549,8 @@ string R4ReduceWindowTestDataToString(
"__strides_", tensorflow::str_util::Join(param.strides, "x"), //
"__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), //
"__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), //
- (param.reducer == kAdd) ? "add" : "max");
+ "__layout_", tensorflow::str_util::Join(param.layout, "_"), //
+ (param.reducer == kAdd) ? "_add" : "_max");
CHECK(param.reducer == kAdd || param.reducer == kMax);
// Test names are not allowed to contain the '-' character.
@@ -575,7 +577,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
param.base_bounds[2], param.base_bounds[3]);
input.FillIota(1);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4D(input);
+ Literal::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
ComputationDataHandle parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
&b, &parameter);
@@ -611,8 +614,13 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides,
/*padding=*/padding);
- ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
- {input_arg.get()}, DefaultErrorSpec());
+ std::unique_ptr<Literal> expected_literal =
+ Literal::CreateFromArray(*expected);
+ const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
+ input_literal->shape().element_type(),
+ AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
+ ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()},
+ DefaultErrorSpec(), &expected_shape_with_layout);
}
};
@@ -626,6 +634,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// Arbitrary padding (not kSame or kValid).
@@ -634,6 +643,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{2, 2, 1, 1},
/*pad_low=*/{4, 4, 0, 0},
/*pad_high=*/{4, 4, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// Zero base bound edge case.
@@ -642,6 +652,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With non-1x1 window.
@@ -650,6 +661,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With max instead of add.
@@ -658,6 +670,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kMax},
// With stride.
@@ -666,6 +679,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{2, 4, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With low padding.
@@ -674,6 +688,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{2, 2, 1, 1},
/*pad_low=*/{3, 2, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With high padding.
@@ -682,6 +697,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{2, 2, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{2, 3, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// Window touches both sides of the padding simultaneously.
@@ -690,6 +706,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{1, 1, 0, 0},
/*pad_high=*/{1, 1, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// Window is entirely in the padding for some positions.
@@ -698,6 +715,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{4, 4, 0, 0},
/*pad_high=*/{4, 4, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// Zero base bound with padding edge case.
@@ -706,6 +724,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 1, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With stride, low padding and high padding.
@@ -714,6 +733,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{3, 1, 1, 1},
/*pad_low=*/{10, 1, 0, 0},
/*pad_high=*/{2, 3, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With second minor dimension == 9.
@@ -722,6 +742,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With minor dimension == 129.
@@ -730,6 +751,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With minor dims reduction and non-overlapped stride.
@@ -738,6 +760,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*strides=*/{1, 1, 2, 2},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
// With minor dims reduction and overlapped stride.
@@ -745,7 +768,8 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
/*window_bounds=*/{1, 1, 4, 4},
/*strides=*/{1, 1, 2, 2},
/*pad_low=*/{0, 0, 0, 0},
- /*pad_high=*/{0, 0, 0, 0},
+ /*pad_high=*/{1, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
};
@@ -762,10 +786,11 @@ XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); }
// Test cases that are large/slow/failed.
const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
- /*window_bounds=*/{3, 3, 1, 1},
- /*strides=*/{1, 1, 1, 1},
+ /*window_bounds=*/{3, 3, 1, 5},
+ /*strides=*/{1, 1, 1, 5},
/*pad_low=*/{1, 1, 0, 0},
/*pad_high=*/{1, 1, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kMax},
R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
@@ -773,6 +798,7 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
/*strides=*/{2, 2, 1, 1},
/*pad_low=*/{0, 0, 0, 0},
/*pad_high=*/{1, 1, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
/*reducer=*/kAdd},
};
@@ -782,6 +808,54 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(use_bfloat16_params)),
R4ReduceWindowTestDataToString);
+class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {};
+
+// TODO(b/72234705): Fix the test cases failed on CPU and GPU.
+XLA_TEST_P(R4ReduceWindowAnyDimsTest,
+ DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
+ DoIt();
+}
+
+const R4ReduceWindowTestData kR4ReduceWindowAnyDimsTestValues[] = {
+ R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
+ /*window_bounds=*/{2, 3, 4, 5},
+ /*strides=*/{1, 1, 1, 1},
+ /*pad_low=*/{0, 0, 0, 0},
+ /*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
+ /*reducer=*/kAdd},
+ R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
+ /*window_bounds=*/{2, 3, 1, 1},
+ /*strides=*/{1, 1, 1, 1},
+ /*pad_low=*/{0, 0, 0, 0},
+ /*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{3, 2, 1, 0},
+ /*reducer=*/kMax},
+ // With 0321 layout.
+ R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
+ /*window_bounds=*/{2, 3, 4, 5},
+ /*strides=*/{1, 2, 3, 4},
+ /*pad_low=*/{0, 0, 0, 0},
+ /*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{0, 3, 2, 1},
+ /*reducer=*/kAdd},
+
+ // With 0123 layout.
+ R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23},
+ /*window_bounds=*/{2, 3, 7, 9},
+ /*strides=*/{1, 2, 5, 8},
+ /*pad_low=*/{0, 0, 0, 0},
+ /*pad_high=*/{0, 0, 0, 0},
+ /*layout=*/{0, 1, 2, 3},
+ /*reducer=*/kAdd},
+};
+
+INSTANTIATE_TEST_CASE_P(
+ R4ReduceWindowAnyDimsTestInstantiation, R4ReduceWindowAnyDimsTest,
+ ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowAnyDimsTestValues),
+ ::testing::ValuesIn(use_bfloat16_params)),
+ R4ReduceWindowTestDataToString);
+
struct R3ReduceWindowTestData {
int64 base_bounds[3];
int64 window_bounds[3];
@@ -942,37 +1016,39 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
::testing::tuple<R2ReduceWindowTestData, bool>> {
protected:
R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
-};
-TEST_P(R2ReduceWindowTest, Add) {
- ComputationBuilder b(client_, TestName());
- const auto& param = ::testing::get<0>(GetParam());
- CHECK(param.reducer == kAdd);
+ void DoIt() {
+ ComputationBuilder b(client_, TestName());
+ const auto& param = ::testing::get<0>(GetParam());
+ CHECK(param.reducer == kAdd);
- const float kInitValue = 0.0f;
- Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
- std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ const float kInitValue = 0.0f;
+ Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
+ std::unique_ptr<Literal> input_literal =
+ Literal::CreateR2FromArray2DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
- ComputationDataHandle parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
- &b, &parameter);
- auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
- b.ReduceWindow(/*operand=*/parameter,
- /*init_value=*/init_value,
- /*computation=*/CreateScalarAddComputation(FloatType(), &b),
- /*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/param.padding);
+ ComputationDataHandle parameter;
+ auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ &b, &parameter);
+ auto init_value =
+ CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
+ b.ReduceWindow(/*operand=*/parameter,
+ /*init_value=*/init_value,
+ /*computation=*/CreateScalarAddComputation(FloatType(), &b),
+ /*window_dimensions=*/param.window_bounds,
+ /*window_strides=*/param.strides, /*padding=*/param.padding);
- auto expected = ReferenceUtil::ReduceWindow2DAdd(
- /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
- /*stride=*/param.strides, /*padding=*/param.padding);
+ auto expected = ReferenceUtil::ReduceWindow2DAdd(
+ /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
+ /*stride=*/param.strides, /*padding=*/param.padding);
- ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
- {input_arg.get()}, DefaultErrorSpec());
-}
+ ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
+ {input_arg.get()}, DefaultErrorSpec());
+ }
+};
+
+TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
INSTANTIATE_TEST_CASE_P(
R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
@@ -980,6 +1056,26 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(use_bfloat16_params)),
R2ReduceWindowTestDataToString);
+class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {};
+
+// TODO(b/72234705): Fix the test cases failed on CPU and GPU.
+XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test,
+ DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
+ DoIt();
+}
+
+const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = {
+ {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128},
+ /*strides=*/{1, 1}, /*layout=*/{1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+};
+
+INSTANTIATE_TEST_CASE_P(
+ R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test,
+ ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test),
+ ::testing::ValuesIn(use_bfloat16_params)),
+ R2ReduceWindowTestDataToString);
+
struct R1ReduceWindowTestData {
int64 base_bounds[1];
int64 window_bounds[1];
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
index 62ff349e9c..9ee94b8571 100644
--- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -39,8 +39,8 @@ namespace xla {
namespace {
struct SelectAndScatterTestParam {
- Array4D<float> operand_shape;
- Array4D<float> source_shape;
+ std::vector<int64> operand_shape;
+ std::vector<int64> source_shape;
Padding padding_type;
tensorflow::gtl::ArraySlice<int64> window_dimensions;
tensorflow::gtl::ArraySlice<int64> window_strides;
@@ -69,83 +69,132 @@ class SelectAndScatterTest
Computation min_f32_;
};
-XLA_TEST_P(SelectAndScatterTest, R4Randomized) {
- Array4D<float> o(GetParam().operand_shape);
+XLA_TEST_P(SelectAndScatterTest, ParamTest) {
+ auto operand_shape = GetParam().operand_shape;
+ Array<float> o(operand_shape);
o.FillRandom(1.5f);
- auto operand = builder_.ConstantR4FromArray4D(o);
+ auto operand = builder_.ConstantFromArray(o);
- Array4D<float> s(GetParam().source_shape);
+ auto source_shape = GetParam().source_shape;
+ Array<float> s(source_shape);
s.FillRandom(12.0f);
- auto source = builder_.ConstantR4FromArray4D(s);
-
- builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions,
- GetParam().window_strides, GetParam().padding_type,
- source, builder_.ConstantR0<float>(0.0f), add_f32_);
+ auto source = builder_.ConstantFromArray(s);
- auto e = ReferenceUtil::SelectAndScatter4DGePlus(
- o, s, 0.0f, GetParam().window_dimensions, GetParam().window_strides,
- GetParam().padding_type == Padding::kSame);
+ auto select_and_scatter = builder_.SelectAndScatter(
+ operand, ge_f32_, GetParam().window_dimensions, GetParam().window_strides,
+ GetParam().padding_type, source, builder_.ConstantR0<float>(0.0f),
+ add_f32_);
- ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-5));
+ ComputeAndCompare(&builder_, select_and_scatter, {}, ErrorSpec(1e-5));
}
INSTANTIATE_TEST_CASE_P(
SelectAndScatterTest_Instantiation, SelectAndScatterTest,
- ::testing::Values(SelectAndScatterTestParam{{6, 6, 256, 128},
- {3, 3, 256, 128},
- Padding::kSame,
- {3, 3, 1, 1},
- {2, 2, 1, 1}},
- SelectAndScatterTestParam{{7, 7, 256, 128},
- {3, 3, 256, 128},
- Padding::kValid,
- {3, 3, 1, 1},
- {2, 2, 1, 1}},
- SelectAndScatterTestParam{{6, 7, 256, 128},
- {3, 3, 256, 128},
- Padding::kValid,
- {2, 3, 1, 1},
- {2, 2, 1, 1}},
- SelectAndScatterTestParam{{6, 7, 256, 128},
- {2, 3, 256, 128},
- Padding::kValid,
- {2, 3, 1, 1},
- {3, 2, 1, 1}},
- SelectAndScatterTestParam{{9, 9, 16, 128},
- {3, 3, 16, 128},
- Padding::kValid,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{3, 3, 4, 4},
- {1, 1, 4, 4},
- Padding::kValid,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{3, 3, 4, 4},
- {1, 1, 4, 4},
- Padding::kValid,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{9, 3, 4, 4},
- {3, 1, 4, 4},
- Padding::kValid,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{7, 3, 4, 4},
- {3, 1, 4, 4},
- Padding::kValid,
- {3, 3, 1, 1},
- {2, 3, 1, 1}},
- SelectAndScatterTestParam{{1, 1, 5, 5},
- {1, 1, 5, 5},
- Padding::kSame,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{7, 7, 8, 256},
- {4, 4, 8, 256},
- Padding::kSame,
- {2, 2, 1, 1},
- {2, 2, 1, 1}}));
+ ::testing::Values(
+ SelectAndScatterTestParam{{6, 6, 6, 4, 4},
+ {3, 3, 3, 4, 4},
+ Padding::kSame,
+ {3, 3, 3, 1, 1},
+ {2, 2, 2, 1, 1}},
+ SelectAndScatterTestParam{{7, 7, 7, 4, 4},
+ {3, 3, 3, 4, 4},
+ Padding::kValid,
+ {3, 3, 3, 1, 1},
+ {2, 2, 2, 1, 1}},
+
+ SelectAndScatterTestParam{{8, 8, 8, 4, 4},
+ {1, 3, 3, 4, 4},
+ Padding::kValid,
+ {8, 4, 4, 1, 1},
+ {1, 2, 2, 1, 1}},
+ SelectAndScatterTestParam{{6, 6, 256, 128},
+ {3, 3, 256, 128},
+ Padding::kSame,
+ {3, 3, 1, 1},
+ {2, 2, 1, 1}},
+ SelectAndScatterTestParam{{7, 7, 256, 128},
+ {3, 3, 256, 128},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {2, 2, 1, 1}},
+ SelectAndScatterTestParam{{6, 7, 256, 128},
+ {3, 3, 256, 128},
+ Padding::kValid,
+ {2, 3, 1, 1},
+ {2, 2, 1, 1}},
+ SelectAndScatterTestParam{{6, 7, 256, 128},
+ {2, 3, 256, 128},
+ Padding::kValid,
+ {2, 3, 1, 1},
+ {3, 2, 1, 1}},
+ SelectAndScatterTestParam{{9, 9, 16, 128},
+ {3, 3, 16, 128},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{3, 3, 4, 4},
+ {1, 1, 4, 4},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{3, 3, 4, 4},
+ {1, 1, 4, 4},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{9, 3, 4, 4},
+ {3, 1, 4, 4},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{7, 3, 4, 4},
+ {3, 1, 4, 4},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {2, 3, 1, 1}},
+ SelectAndScatterTestParam{{1, 1, 5, 5},
+ {1, 1, 5, 5},
+ Padding::kSame,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{7, 7, 8, 256},
+ {4, 4, 8, 256},
+ Padding::kSame,
+ {2, 2, 1, 1},
+ {2, 2, 1, 1}},
+ SelectAndScatterTestParam{
+ {6, 4, 4}, {3, 4, 4}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
+ SelectAndScatterTestParam{
+ {6, 256, 128}, {3, 256, 128}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
+ SelectAndScatterTestParam{{7, 256, 128},
+ {3, 256, 128},
+ Padding::kValid,
+ {3, 1, 1},
+ {2, 1, 1}},
+ SelectAndScatterTestParam{{6, 256, 128},
+ {3, 256, 128},
+ Padding::kValid,
+ {2, 1, 1},
+ {2, 1, 1}},
+ SelectAndScatterTestParam{{6, 256, 128},
+ {2, 256, 128},
+ Padding::kValid,
+ {2, 1, 1},
+ {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {9, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {7, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {2, 1, 1}},
+ SelectAndScatterTestParam{
+ {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}}));
// Test for F32 1D array, with a zero-element input.
XLA_TEST_F(SelectAndScatterTest, R1S0F32) {
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 8b10aef5b8..0e90a32358 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -34,7 +34,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal) {
TF_CHECK_OK(literal->Populate<FloatT>(
[&](tensorflow::gtl::ArraySlice<int64> indices) {
// Generate a random uniforma number from -0.0625 and 0.0625 and bias it
- // with a position dependent nubmer with mean 0.037109375. These number
+ // with a position dependent number with mean 0.037109375. These number
// should allow for long chains of accumulation without being too close
// to zero or to large to accumulate all numbers accurately.
return (generator(engine) - 1.0625) +
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index fa4192e928..835e2d7e55 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -215,5 +215,23 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
ComputeAndCompareR2<float>(&builder, {{0, 0}, {0, 0}}, {});
}
+XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({0, 1});
+ auto rhs = builder.ConstantR1<int32>({1, 1});
+ builder.ConvertElementType(builder.Eq(lhs, rhs), S32);
+
+ ComputeAndCompareR1<int32>(&builder, {0, 1}, {});
+}
+
+XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({0, 1});
+ auto rhs = builder.ConstantR1<int32>({1, 1});
+ builder.ConvertElementType(builder.Eq(lhs, rhs), F32);
+
+ ComputeAndCompareR1<float>(&builder, {0.0, 1.0}, {});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 146fbadcb6..9ad2a19853 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -19,12 +19,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -32,6 +34,7 @@ limitations under the License.
namespace xla {
namespace {
namespace se = ::perftools::gputools;
+namespace gtl = ::tensorflow::gtl;
class HloProfileTest : public ClientLibraryTestBase {};
@@ -43,39 +46,74 @@ struct ParsedProfileOutputLine {
string trops;
string bytes_per_sec;
string bytes_per_cycle;
- string name;
+ string opcode;
};
-StatusOr<ParsedProfileOutputLine> ParseProfileOutputLine(const string& line,
- bool expect_flops,
- bool expect_trops) {
+::testing::AssertionResult HasFlops(
+ const ParsedProfileOutputLine& parsed_line) {
+ if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) {
+ return ::testing::AssertionSuccess()
+ << "'flops' field present in " << parsed_line.opcode << ": '"
+ << parsed_line.flops << "'";
+ }
+
+ return ::testing::AssertionFailure()
+ << "'flops' field absent in " << parsed_line.opcode << ": '"
+ << parsed_line.flops << "'";
+}
+
+::testing::AssertionResult HasTrops(
+ const ParsedProfileOutputLine& parsed_line) {
+ if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) {
+ return ::testing::AssertionSuccess()
+ << "'trops' field present in " << parsed_line.opcode << ": '"
+ << parsed_line.trops << "'";
+ }
+
+ return ::testing::AssertionFailure()
+ << "'trops' field absent in " << parsed_line.opcode << ": '"
+ << parsed_line.trops << "'";
+}
+
+Status ParseOneProfileOutputLine(
+ const string& line, bool expect_hlo,
+ gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results) {
string separator = "[^:]*:: +";
string match_percentage = "\\d+\\.\\d\\d%";
string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)";
string match_usecs = "([0-9.]+) usec";
- string match_flops = expect_flops ? "([0-9.TGMk]+)FLOP/s" : "(<none>)";
- string match_trops = expect_trops ? "([0-9.TGMk]+)TROP/s" : "(<none>)";
+ string match_flops = "([^ ]+)";
+ string match_trops = "([^ ]+)";
string match_bytes_per_sec = "([0-9.TGMKi]+)B/s";
string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle";
+
+ // The underlined part is what we're trying to match with match_opcode:
+ //
+ // %dot33 = f32[256,256]{1,0} dot(...)
+ // ^^^
+
+ string match_opcode =
+ expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])";
string regexp_pattern = tensorflow::strings::StrCat(
" +", match_cycles, separator, match_usecs, separator, match_flops,
separator, match_trops, separator, match_bytes_per_sec, separator,
- match_bytes_per_cycle, separator, "(.*)");
+ match_bytes_per_cycle, separator, match_opcode);
- RE2 pattern(regexp_pattern);
ParsedProfileOutputLine parsed_line;
bool matched = RE2::FullMatch(
- line, pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
+ line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
&parsed_line.usec, &parsed_line.flops, &parsed_line.trops,
&parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle,
- &parsed_line.name);
+ &parsed_line.opcode);
if (!matched) {
return tensorflow::errors::InvalidArgument(
"Input did not match regexp. Input: ", line,
", Regexp: ", regexp_pattern);
}
- return parsed_line;
+ InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+
+ return Status::OK();
}
// Returns void so that we can ASSERT.
@@ -110,7 +148,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
Executable* executable = local_executable->executable();
HloExecutionProfile hlo_execution_profile(
- &executable->hlo_profile_printer(), &executable->hlo_profile_index_map());
+ &executable->hlo_profile_printer_data(),
+ &executable->hlo_profile_index_map());
TF_ASSERT_OK_AND_ASSIGN(
Backend::StreamPtr stream_ptr,
@@ -147,7 +186,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) {
ClientLibrary::GetOrCreateLocalClient(platform));
ComputationBuilder builder(client, TestName());
- auto result = builder.Tanh(builder.Dot(
+ auto result = builder.Tanh(builder.Add(
builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"),
builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs")));
@@ -160,31 +199,43 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) {
std::vector<string> profile_output_lines =
tensorflow::str_util::Split(profile_output, '\n');
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine total_profile,
- ParseProfileOutputLine(profile_output_lines[1], /*expect_flops=*/true,
- /*expect_trops=*/true));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine dot_profile,
- ParseProfileOutputLine(profile_output_lines[2], /*expect_flops=*/true,
- /*expect_trops=*/false));
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines));
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine tanh_profile,
- ParseProfileOutputLine(profile_output_lines[3], /*expect_flops=*/false,
- /*expect_trops=*/true));
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[2], /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[3], /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile,
+ MaybeFind(parsed_profile_lines, "[total]"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
+ MaybeFind(parsed_profile_lines, "add"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile,
+ MaybeFind(parsed_profile_lines, "tanh"));
EXPECT_GT(total_profile.cycles, 0);
EXPECT_EQ(total_profile.cycles_percentage, "100.00%");
+ EXPECT_TRUE(HasFlops(total_profile));
+ EXPECT_TRUE(HasTrops(total_profile));
+
EXPECT_GT(total_profile.cycles, dot_profile.cycles);
EXPECT_NE(dot_profile.cycles_percentage, "0.00%");
EXPECT_NE(dot_profile.cycles_percentage, "100.00%");
+ EXPECT_TRUE(HasFlops(dot_profile));
+ EXPECT_FALSE(HasTrops(dot_profile));
+
EXPECT_GT(total_profile.cycles, tanh_profile.cycles);
EXPECT_NE(tanh_profile.cycles_percentage, "0.00%");
EXPECT_NE(tanh_profile.cycles_percentage, "100.00%");
+
+ EXPECT_FALSE(HasFlops(tanh_profile));
+ EXPECT_TRUE(HasTrops(tanh_profile));
}
// TODO(b/71364943): This test exposes a bug in the parallel CPU backend.
@@ -219,7 +270,7 @@ XLA_TEST_F(HloProfileTest,
auto matrix = builder.GetTupleElement(state, 1);
auto next_iteration = builder.Add(builder.GetTupleElement(state, 0),
builder.ConstantR0<int32>(1));
- builder.Tuple({next_iteration, builder.Dot(matrix, matrix)});
+ builder.Tuple({next_iteration, builder.Add(matrix, matrix)});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
@@ -248,20 +299,23 @@ XLA_TEST_F(HloProfileTest,
ASSERT_NE(while_body_profile_start, profile_output_lines.end());
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine total_while_body_profile,
- ParseProfileOutputLine(*std::next(while_body_profile_start, 1),
- /*expect_flops=*/false,
- /*expect_trops=*/false));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine dot_profile,
- ParseProfileOutputLine(*std::next(while_body_profile_start, 2),
- /*expect_flops=*/false,
- /*expect_trops=*/false));
+ TF_ASSERT_OK(
+ ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1),
+ /*expect_hlo=*/false, &parsed_profile_lines));
+
+ TF_ASSERT_OK(
+ ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2),
+ /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
+ MaybeFind(parsed_profile_lines, "[total]"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
+ MaybeFind(parsed_profile_lines, "add"));
EXPECT_GT(total_while_body_profile.cycles, 0);
- EXPECT_EQ(total_while_body_profile.name, "[total]");
+ EXPECT_EQ(total_while_body_profile.opcode, "[total]");
EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%");
EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles);
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 1c68e271e0..42e7f91f26 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -931,7 +931,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
- shape, operands[0], config ? *config : ""));
+ operands[0]->shape(), operands[0], config ? *config : ""));
break;
}
case HloOpcode::kRng: {
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index fe5d29a6b6..b020905035 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -30,9 +30,7 @@ limitations under the License.
#include "tensorflow/core/platform/stacktrace.h"
namespace xla {
-namespace {
-// Logs the provided status message with a backtrace.
Status WithLogBacktrace(const Status& status) {
CHECK(!status.ok());
VLOG(1) << status.ToString();
@@ -40,8 +38,6 @@ Status WithLogBacktrace(const Status& status) {
return status;
}
-} // namespace
-
ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled)
: enabled(enabled), label(label) {
if (enabled) {
@@ -74,13 +70,18 @@ Status AppendStatus(Status prior, tensorflow::StringPiece context) {
// Implementation note: we can't common these out (without using macros) because
// they all need to va_start/va_end their varargs in their frame.
-Status InvalidArgument(const char* format, ...) {
+Status InvalidArgumentV(const char* format, va_list args) {
string message;
+ tensorflow::strings::Appendv(&message, format, args);
+ return WithLogBacktrace(tensorflow::errors::InvalidArgument(message));
+}
+
+Status InvalidArgument(const char* format, ...) {
va_list args;
va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
+ Status result = InvalidArgumentV(format, args);
va_end(args);
- return WithLogBacktrace(tensorflow::errors::InvalidArgument(message));
+ return result;
}
Status Unimplemented(const char* format, ...) {
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index bb2db2010c..4bc2d632cd 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -40,6 +40,13 @@ limitations under the License.
namespace xla {
+// Logs the provided status message with a backtrace.
+//
+// For use by Status-factories, logs a backtrace at the point where the status
+// is created, such that we can use --vmodule=util=1 to see all status
+// creation backtraces.
+Status WithLogBacktrace(const Status& status);
+
// Ranks greater than 8 are very rare, so use InlinedVector<int64, 8> to store
// the bounds and indices. And for the rare cases of ranks greater than 8,
// the InlinedVector will just behave like an std::vector<> and allocate the
@@ -207,6 +214,9 @@ Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+// Passed-varargs variant of the InvalidArgument factory above.
+Status InvalidArgumentV(const char* format, va_list args);
+
// Splits the lines of the original, replaces leading whitespace with the prefix
// given by "indentation", and returns the string joined by newlines again. As a
// side effect, any additional trailing whitespace is removed.
@@ -398,13 +408,11 @@ std::vector<std::pair<int64, int64>> CommonFactors(
// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);
-// Simple wrapper around std::all_of.
template <typename Container, typename Predicate>
bool c_all_of(Container container, Predicate predicate) {
return std::all_of(std::begin(container), std::end(container), predicate);
}
-// Simple wrapper around std::transform.
template <typename InputContainer, typename OutputIterator,
typename UnaryOperation>
OutputIterator c_transform(InputContainer input_container,
@@ -414,7 +422,6 @@ OutputIterator c_transform(InputContainer input_container,
output_iterator, unary_op);
}
-// Simple wrapper around std::copy_if.
template <class InputContainer, class OutputIterator, class UnaryPredicate>
OutputIterator c_copy_if(InputContainer input_container,
OutputIterator output_iterator,
@@ -423,6 +430,11 @@ OutputIterator c_copy_if(InputContainer input_container,
output_iterator, predicate);
}
+template <class InputContainer, class Comparator>
+void c_sort(InputContainer& input_container, Comparator comparator) {
+ std::sort(input_container.begin(), input_container.end(), comparator);
+}
+
} // namespace xla
#define XLA_LOG_LINES(SEV, STRING) \
diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc
index 224eb2a20c..55f42ed3a4 100644
--- a/tensorflow/compiler/xla/window_util.cc
+++ b/tensorflow/compiler/xla/window_util.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
@@ -25,6 +26,26 @@ limitations under the License.
namespace xla {
namespace window_util {
+Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) {
+ Window window;
+ for (int64 size : sizes) {
+ auto* dimension = window.add_dimensions();
+ dimension->set_size(size);
+ dimension->set_stride(1);
+ }
+ return window;
+}
+
+PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes) {
+ PaddingConfig config;
+ for (int64 size : sizes) {
+ auto* dimension = config.add_dimensions();
+ dimension->set_edge_padding_low(size);
+ dimension->set_edge_padding_high(size);
+ }
+ return config;
+}
+
/* static */ string ToString(const WindowDimension& dim) {
using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;
@@ -114,13 +135,21 @@ bool HasPadding(const Window& window) {
return false;
}
-bool HasEvenPadding(const Window& window) {
+bool HasSymmetricPadding(const Window& window) {
return std::all_of(window.dimensions().begin(), window.dimensions().end(),
[](const WindowDimension& dim) {
return dim.padding_low() == dim.padding_high();
});
}
+bool HasSymmetricPadding(const PaddingConfig& padding_config) {
+ return std::all_of(padding_config.dimensions().begin(),
+ padding_config.dimensions().end(),
+ [](const PaddingConfig::PaddingConfigDimension& dim) {
+ return dim.edge_padding_low() == dim.edge_padding_high();
+ });
+}
+
bool HasNegativePadding(const Window& window) {
return std::any_of(window.dimensions().begin(), window.dimensions().end(),
[](const WindowDimension& dim) {
diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h
index 17c388fc0b..ba473e2c8c 100644
--- a/tensorflow/compiler/xla/window_util.h
+++ b/tensorflow/compiler/xla/window_util.h
@@ -18,10 +18,21 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace window_util {
+// Creates a window with the given sizes in the dimensions and all strides set
+// to 1.
+Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes);
+
+// Creates a padding config with symmetrical padding in each dimension, of value
+// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
+// pixels of padding in dimension 0, one pixel of padding symmetrically, on each
+// side of dimension 1, and two pixels of padding symmetrically on dimension 2.
+PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice<int64> sizes);
+
string ToString(const WindowDimension& dim);
string ToString(const Window& window);
@@ -32,9 +43,14 @@ string ToString(const Window& window);
bool HasStride(const Window& window);
bool HasPadding(const Window& window);
-bool HasEvenPadding(const Window& window);
+bool HasSymmetricPadding(const Window& window);
bool HasNegativePadding(const Window& window);
+// As with HasSymmetricPadding(Window) above, returns whether the "padding low"
+// is equivalent to the "padding high" for all dimensions, but works on a
+// padding configuration.
+bool HasSymmetricPadding(const PaddingConfig& padding_config);
+
bool HasBaseDilation(const Window& window);
bool HasWindowDilation(const Window& window);
bool HasDilation(const Window& window);
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index fda1a4c27b..e1ed08c848 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -179,6 +179,10 @@ message DebugOptions {
// ops.
bool xla_gpu_use_cudnn_batchnorm = 94;
+ // Dump compilation artifacts, before hlo passes are executed, in binary proto
+ // into this directory.
+ string xla_dump_prepass_hlo_proto_to = 95;
+
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index e5c3017426..d4c0660285 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -114,7 +114,6 @@ cc_library(
name = "contrib_kernels",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/batching:batch_ops_kernels",
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
"//tensorflow/contrib/coder:all_kernels",
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels",
@@ -137,7 +136,6 @@ cc_library(
name = "contrib_ops_op_lib",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/batching:batch_ops_op_lib",
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index f600a8a998..8f6a3cb1ca 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# Add projects here, they will show up under tf.contrib.
+from tensorflow.contrib import batching
from tensorflow.contrib import bayesflow
from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
diff --git a/tensorflow/contrib/android/asset_manager_filesystem.h b/tensorflow/contrib/android/asset_manager_filesystem.h
index 2b43939f14..665304b5ee 100644
--- a/tensorflow/contrib/android/asset_manager_filesystem.h
+++ b/tensorflow/contrib/android/asset_manager_filesystem.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
+#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
+#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
@@ -79,4 +79,4 @@ class AssetManagerFileSystem : public FileSystem {
};
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
+#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index cd98f0e703..ee67909133 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -67,48 +67,14 @@ load(
)
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-tf_custom_op_library(
- name = "python/ops/_batch_ops.so",
- srcs = ["ops/batch_ops.cc"],
- deps = [
- "//tensorflow/contrib/batching/kernels:batch_kernels",
- ],
-)
-
-tf_gen_op_libs(
- op_lib_names = ["batch_ops"],
-)
-
-tf_gen_op_wrapper_py(
- name = "batch_ops",
- deps = [":batch_ops_op_lib"],
-)
-
-tf_kernel_library(
- name = "batch_ops_kernels",
- deps = [
- "//tensorflow/contrib/batching/kernels:batch_kernels",
- "//tensorflow/contrib/batching/util:periodic_function",
- "//tensorflow/core/kernels:concat_lib",
- "//tensorflow/core/kernels:ops_util",
- "//tensorflow/core/kernels:split_lib",
- ],
- alwayslink = 1,
-)
-
-tf_custom_op_py_library(
+py_library(
name = "batch_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
- dso = [":python/ops/_batch_ops.so"],
- kernels = [
- ":batch_ops_kernels",
- ":batch_ops_op_lib",
- ],
srcs_version = "PY2AND3",
deps = [
- ":batch_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:batch_ops_gen",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
@@ -118,6 +84,14 @@ tf_custom_op_py_library(
],
)
+cc_library(
+ name = "batch_ops_kernels",
+ deps = [
+ "//tensorflow/core/kernels:batch_kernels",
+ ],
+ alwayslink = 1,
+)
+
py_test(
name = "batch_ops_test",
size = "small",
@@ -133,6 +107,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
"//tensorflow/python:gradients",
"//tensorflow/python:script_ops",
],
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
index 60861f83f4..86250e6692 100644
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+#ifndef TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+#endif // TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h
index 63ba8fcf45..d9b37da693 100644
--- a/tensorflow/contrib/batching/basic_batch_scheduler.h
+++ b/tensorflow/contrib/batching/basic_batch_scheduler.h
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
+#ifndef TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
+#endif // TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h
index 3afce2761f..8e94e1fd8b 100644
--- a/tensorflow/contrib/batching/batch_scheduler.h
+++ b/tensorflow/contrib/batching/batch_scheduler.h
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
+#ifndef TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
+#endif // TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/kernels/BUILD b/tensorflow/contrib/batching/kernels/BUILD
deleted file mode 100644
index 6e53dd9a5f..0000000000
--- a/tensorflow/contrib/batching/kernels/BUILD
+++ /dev/null
@@ -1,34 +0,0 @@
-# Description:
-# Contains kernels for the batching ops.
-
-package(default_visibility = ["//tensorflow:__subpackages__"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-cc_library(
- name = "batch_kernels",
- srcs = ["batch_kernels.cc"],
- deps = [
- "//tensorflow/contrib/batching:shared_batch_scheduler_hdrs",
- "//tensorflow/contrib/batching/util:periodic_function_dynamic",
- "//tensorflow/core:framework_headers_lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/kernels:concat_lib_hdrs",
- "//tensorflow/core/kernels:ops_util_hdrs",
- "//tensorflow/core/kernels:split_lib_hdrs",
- ],
- alwayslink = 1,
-)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/batching/ops/batch_ops.cc b/tensorflow/contrib/batching/ops/batch_ops.cc
deleted file mode 100644
index 85e0ccba4a..0000000000
--- a/tensorflow/contrib/batching/ops/batch_ops.cc
+++ /dev/null
@@ -1,164 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace tensorflow {
-
-REGISTER_OP("Batch")
- .Input("in_tensors: T")
- .Output("batched_tensors: T")
- .Output("batch_index: int64")
- .Output("id: int64")
- .Attr("num_batch_threads: int")
- .Attr("max_batch_size: int")
- .Attr("batch_timeout_micros: int")
- .Attr("allowed_batch_sizes: list(int) = []")
- .Attr("grad_timeout_micros: int")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("batching_queue: string = ''")
- .Attr("T: list(type)")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- std::vector<shape_inference::ShapeHandle> in_shapes;
- TF_RETURN_IF_ERROR(c->input("in_tensors", &in_shapes));
- std::vector<shape_inference::ShapeHandle> out_shapes(in_shapes.size());
- for (int i = 0; i < in_shapes.size(); ++i) {
- TF_RETURN_IF_ERROR(
- c->ReplaceDim(in_shapes[i], 0, c->UnknownDim(), &out_shapes[i]));
- }
- TF_RETURN_IF_ERROR(c->set_output("batched_tensors", out_shapes));
- TF_RETURN_IF_ERROR(c->set_output("id", {c->Scalar()}));
- TF_RETURN_IF_ERROR(c->set_output(
- "batch_index",
- {c->MakeShape({shape_inference::DimensionOrConstant(c->UnknownDim()),
- shape_inference::DimensionOrConstant(3)})}));
- return Status::OK();
- })
- .Doc(R"doc(
-Batches all input tensors nondeterministically.
-
-When many instances of this Op are being run concurrently with the same
-container/shared_name in the same device, some will output zero-shaped Tensors
-and others will output Tensors of size up to max_batch_size.
-
-All Tensors in in_tensors are batched together (so, for example, labels and
-features should be batched with a single instance of this operation.
-
-Each invocation of batch emits an `id` scalar which will be used to identify
-this particular invocation when doing unbatch or its gradient.
-
-Each op which emits a non-empty batch will also emit a non-empty batch_index
-Tensor, which, is a [K, 3] matrix where each row contains the invocation's id,
-start, and length of elements of each set of Tensors present in batched_tensors.
-
-Batched tensors are concatenated along the first dimension, and all tensors in
-in_tensors must have the first dimension of the same size.
-
-in_tensors: The tensors to be batched.
-num_batch_threads: Number of scheduling threads for processing batches of work.
- Determines the number of batches processed in parallel.
-max_batch_size: Batch sizes will never be bigger than this.
-batch_timeout_micros: Maximum number of microseconds to wait before outputting
- an incomplete batch.
-allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does
- nothing. Otherwise, supplies a list of batch sizes, causing the op to pad
- batches up to one of those sizes. The entries must increase monotonically, and
- the final entry must equal max_batch_size.
-grad_timeout_micros: The timeout to use for the gradient. See Unbatch.
-batched_tensors: Either empty tensors or a batch of concatenated Tensors.
-batch_index: If out_tensors is non-empty, has information to invert it.
-container: Controls the scope of sharing of this batch.
-id: always contains a scalar with a unique ID for this invocation of Batch.
-shared_name: Concurrently running instances of batch in the same device with the
- same container and shared_name will batch their elements together. If left
- empty, the op name will be used as the shared name.
-T: the types of tensors to be batched.
-)doc");
-
-REGISTER_OP("Unbatch")
- .Input("batched_tensor: T")
- .Input("batch_index: int64")
- .Input("id: int64")
- .Output("unbatched_tensor: T")
- .Attr("timeout_micros: int")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("T: type")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle out_shape;
- TF_RETURN_IF_ERROR(
- c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &out_shape));
- c->set_output(0, out_shape);
- return Status::OK();
- })
- .Doc(R"doc(
-Reverses the operation of Batch for a single output Tensor.
-
-An instance of Unbatch either receives an empty batched_tensor, in which case it
-asynchronously waits until the values become available from a concurrently
-running instance of Unbatch with the same container and shared_name, or receives
-a non-empty batched_tensor in which case it finalizes all other concurrently
-running instances and outputs its own element from the batch.
-
-batched_tensor: The possibly transformed output of Batch. The size of the first
- dimension should remain unchanged by the transformations for the operation to
- work.
-batch_index: The matching batch_index obtained from Batch.
-id: The id scalar emitted by Batch.
-unbatched_tensor: The Tensor corresponding to this execution.
-timeout_micros: Maximum amount of time (in microseconds) to wait to receive the
- batched input tensor associated with a given invocation of the op.
-container: Container to control resource sharing.
-shared_name: Instances of Unbatch with the same container and shared_name are
- assumed to possibly belong to the same batch. If left empty, the op name will
- be used as the shared name.
-)doc");
-
-REGISTER_OP("UnbatchGrad")
- .Input("original_input: T")
- .Input("batch_index: int64")
- .Input("grad: T")
- .Input("id: int64")
- .Output("batched_grad: T")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("T: type")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(2))));
- return Status::OK();
- })
- .Doc(R"doc(
-Gradient of Unbatch.
-
-Acts like Batch but using the given batch_index index of batching things as they
-become available. This ensures that the gradients are propagated back in the
-same session which did the forward pass.
-
-original_input: The input to the Unbatch operation this is the gradient of.
-batch_index: The batch_index given to the Unbatch operation this is the gradient
-of.
-grad: The downstream gradient.
-id: The id scalar emitted by Batch.
-batched_grad: The return value, either an empty tensor or the batched gradient.
-container: Container to control resource sharing.
-shared_name: Instances of UnbatchGrad with the same container and shared_name
- are assumed to possibly belong to the same batch. If left empty, the op name
- will be used as the shared name.
- )doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py
index cee4d7b4a9..4e0b3f9af9 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops.py
@@ -18,18 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.batching.ops import gen_batch_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_batch_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
-from tensorflow.contrib.batching.ops.gen_batch_ops import *
+from tensorflow.python.ops.gen_batch_ops import *
# pylint: enable=wildcard-import
-from tensorflow.contrib.util import loader
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import resource_loader
-
-
-_batch_ops = loader.load_op_library(
- resource_loader.get_path_to_datafile("_batch_ops.so"))
@ops.RegisterGradient("Batch")
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h
index 7eb1e20c42..83a59695d7 100644
--- a/tensorflow/contrib/batching/shared_batch_scheduler.h
+++ b/tensorflow/contrib/batching/shared_batch_scheduler.h
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
+#ifndef TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
+#endif // TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.h b/tensorflow/contrib/batching/test_util/fake_clock_env.h
index ced27a8833..40a39a5569 100644
--- a/tensorflow/contrib/batching/test_util/fake_clock_env.h
+++ b/tensorflow/contrib/batching/test_util/fake_clock_env.h
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
+#ifndef TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
+#define TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
+#endif // TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h
index fb61bc2eea..aa2ed0a385 100644
--- a/tensorflow/contrib/batching/util/periodic_function.h
+++ b/tensorflow/contrib/batching/util/periodic_function.h
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
+#ifndef TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
+#define TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
+#endif // TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py
index 57f44aef1a..750afb6654 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py
@@ -18,15 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib
from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.platform import test
@@ -46,7 +52,7 @@ class Counter(object):
class MockDistribution(independent_lib.Independent):
- """Monitors DenseVariational calls to the underlying distribution."""
+ """Monitors layer calls to the underlying distribution."""
def __init__(self, result_sample, result_log_prob, loc=None, scale=None):
self.result_sample = result_sample
@@ -104,11 +110,14 @@ class ConvVariational(test.TestCase):
def _testKLPenaltyKernel(self, layer_class):
with self.test_session():
layer = layer_class(filters=2, kernel_size=3)
- if layer_class == prob_layers_lib.Conv1DVariational:
+ if layer_class in (prob_layers_lib.Conv1DReparameterization,
+ prob_layers_lib.Conv1DFlipout):
inputs = random_ops.random_uniform([2, 3, 1], seed=1)
- elif layer_class == prob_layers_lib.Conv2DVariational:
+ elif layer_class in (prob_layers_lib.Conv2DReparameterization,
+ prob_layers_lib.Conv2DFlipout):
inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1)
- elif layer_class == prob_layers_lib.Conv3DVariational:
+ elif layer_class in (prob_layers_lib.Conv3DReparameterization,
+ prob_layers_lib.Conv3DFlipout):
inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1)
# No keys.
@@ -133,11 +142,14 @@ class ConvVariational(test.TestCase):
kernel_size=3,
bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(),
bias_prior_fn=_make_normal)
- if layer_class == prob_layers_lib.Conv1DVariational:
+ if layer_class in (prob_layers_lib.Conv1DReparameterization,
+ prob_layers_lib.Conv1DFlipout):
inputs = random_ops.random_uniform([2, 3, 1], seed=1)
- elif layer_class == prob_layers_lib.Conv2DVariational:
+ elif layer_class in (prob_layers_lib.Conv2DReparameterization,
+ prob_layers_lib.Conv2DFlipout):
inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1)
- elif layer_class == prob_layers_lib.Conv3DVariational:
+ elif layer_class in (prob_layers_lib.Conv3DReparameterization,
+ prob_layers_lib.Conv3DFlipout):
inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1)
# No keys.
@@ -152,42 +164,78 @@ class ConvVariational(test.TestCase):
self.assertEqual(len(losses), 2)
self.assertListEqual(layer.losses, losses)
- def _testConvVariational(self, layer_class):
+ def _testConvSetUp(self, layer_class, batch_size, depth=None,
+ height=None, width=None, channels=None, filters=None,
+ **kwargs):
+ seed = Counter()
+ if layer_class in (prob_layers_lib.Conv1DReparameterization,
+ prob_layers_lib.Conv1DFlipout):
+ inputs = random_ops.random_uniform(
+ [batch_size, width, channels], seed=seed())
+ kernel_size = (2,)
+ elif layer_class in (prob_layers_lib.Conv2DReparameterization,
+ prob_layers_lib.Conv2DFlipout):
+ inputs = random_ops.random_uniform(
+ [batch_size, height, width, channels], seed=seed())
+ kernel_size = (2, 2)
+ elif layer_class in (prob_layers_lib.Conv3DReparameterization,
+ prob_layers_lib.Conv3DFlipout):
+ inputs = random_ops.random_uniform(
+ [batch_size, depth, height, width, channels], seed=seed())
+ kernel_size = (2, 2, 2)
+
+ kernel_shape = kernel_size + (channels, filters)
+ kernel_posterior = MockDistribution(
+ loc=random_ops.random_uniform(kernel_shape, seed=seed()),
+ scale=random_ops.random_uniform(kernel_shape, seed=seed()),
+ result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()),
+ result_sample=random_ops.random_uniform(kernel_shape, seed=seed()))
+ kernel_prior = MockDistribution(
+ result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()),
+ result_sample=random_ops.random_uniform(kernel_shape, seed=seed()))
+ kernel_divergence = MockKLDivergence(
+ result=random_ops.random_uniform(kernel_shape, seed=seed()))
+
+ bias_size = (filters,)
+ bias_posterior = MockDistribution(
+ result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
+ result_sample=random_ops.random_uniform(bias_size, seed=seed()))
+ bias_prior = MockDistribution(
+ result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
+ result_sample=random_ops.random_uniform(bias_size, seed=seed()))
+ bias_divergence = MockKLDivergence(
+ result=random_ops.random_uniform(bias_size, seed=seed()))
+
+ layer = layer_class(
+ filters=filters,
+ kernel_size=kernel_size,
+ padding="SAME",
+ kernel_posterior_fn=lambda *args: kernel_posterior,
+ kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
+ kernel_prior_fn=lambda *args: kernel_prior,
+ kernel_divergence_fn=kernel_divergence,
+ bias_posterior_fn=lambda *args: bias_posterior,
+ bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
+ bias_prior_fn=lambda *args: bias_prior,
+ bias_divergence_fn=bias_divergence,
+ **kwargs)
+
+ outputs = layer(inputs)
+
+ kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ return (kernel_posterior, kernel_prior, kernel_divergence,
+ bias_posterior, bias_prior, bias_divergence,
+ layer, inputs, outputs, kl_penalty, kernel_shape)
+
+ def _testConvReparameterization(self, layer_class):
batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
with self.test_session() as sess:
- seed = Counter()
- if layer_class == prob_layers_lib.Conv1DVariational:
- inputs = random_ops.random_uniform(
- [batch_size, width, channels], seed=seed())
- kernel_size = (2,)
- elif layer_class == prob_layers_lib.Conv2DVariational:
- inputs = random_ops.random_uniform(
- [batch_size, height, width, channels], seed=seed())
- kernel_size = (2, 2)
- elif layer_class == prob_layers_lib.Conv3DVariational:
- inputs = random_ops.random_uniform(
- [batch_size, depth, height, width, channels], seed=seed())
- kernel_size = (2, 2, 2)
-
- kernel_shape = kernel_size + (channels, filters)
- kernel_posterior = MockDistribution(
- result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()),
- result_sample=random_ops.random_uniform(kernel_shape, seed=seed()))
- kernel_prior = MockDistribution(
- result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()),
- result_sample=random_ops.random_uniform(kernel_shape, seed=seed()))
- kernel_divergence = MockKLDivergence(
- result=random_ops.random_uniform(kernel_shape, seed=seed()))
-
- bias_size = (filters,)
- bias_posterior = MockDistribution(
- result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
- result_sample=random_ops.random_uniform(bias_size, seed=seed()))
- bias_prior = MockDistribution(
- result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
- result_sample=random_ops.random_uniform(bias_size, seed=seed()))
- bias_divergence = MockKLDivergence(
- result=random_ops.random_uniform(bias_size, seed=seed()))
+ (kernel_posterior, kernel_prior, kernel_divergence,
+ bias_posterior, bias_prior, bias_divergence, layer, inputs,
+ outputs, kl_penalty, kernel_shape) = self._testConvSetUp(
+ layer_class, batch_size,
+ depth=depth, height=height, width=width, channels=channels,
+ filters=filters)
convolution_op = nn_ops.Convolution(
tensor_shape.TensorShape(inputs.shape),
@@ -198,23 +246,6 @@ class ConvVariational(test.TestCase):
bias_posterior.result_sample,
data_format="NHWC")
- layer = layer_class(
- filters=filters,
- kernel_size=kernel_size,
- padding="SAME",
- kernel_posterior_fn=lambda *args: kernel_posterior,
- kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
- kernel_prior_fn=lambda *args: kernel_prior,
- kernel_divergence_fn=kernel_divergence,
- bias_posterior_fn=lambda *args: bias_posterior,
- bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
- bias_prior_fn=lambda *args: bias_prior,
- bias_divergence_fn=bias_divergence)
-
- outputs = layer(inputs)
-
- kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
-
[
expected_outputs_, actual_outputs_,
expected_kernel_, actual_kernel_,
@@ -257,32 +288,233 @@ class ConvVariational(test.TestCase):
bias_posterior.result_sample]],
bias_divergence.args)
- def testKLPenaltyKernelConv1DVariational(self):
- self._testKLPenaltyKernel(prob_layers_lib.Conv1DVariational)
+ def _testConvFlipout(self, layer_class):
+ batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
+ with self.test_session() as sess:
+ (kernel_posterior, kernel_prior, kernel_divergence,
+ bias_posterior, bias_prior, bias_divergence, layer, inputs,
+ outputs, kl_penalty, kernel_shape) = self._testConvSetUp(
+ layer_class, batch_size,
+ depth=depth, height=height, width=width, channels=channels,
+ filters=filters, seed=44)
+
+ convolution_op = nn_ops.Convolution(
+ tensor_shape.TensorShape(inputs.shape),
+ filter_shape=tensor_shape.TensorShape(kernel_shape),
+ padding="SAME")
+
+ expected_kernel_posterior_affine = normal_lib.Normal(
+ loc=array_ops.zeros_like(kernel_posterior.result_loc),
+ scale=kernel_posterior.result_scale)
+ expected_kernel_posterior_affine_tensor = (
+ expected_kernel_posterior_affine.sample(seed=42))
+
+ expected_outputs = convolution_op(
+ inputs, kernel_posterior.distribution.loc)
+
+ input_shape = array_ops.shape(inputs)
+ output_shape = array_ops.shape(expected_outputs)
+ batch_shape = array_ops.expand_dims(input_shape[0], 0)
+ channels = input_shape[-1]
+ rank = len(inputs.get_shape()) - 2
+
+ sign_input = random_ops.random_uniform(
+ array_ops.concat([batch_shape,
+ array_ops.expand_dims(channels, 0)], 0),
+ minval=0,
+ maxval=2,
+ dtype=dtypes.int32,
+ seed=layer.seed)
+ sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype)
+ sign_output = random_ops.random_uniform(
+ array_ops.concat([batch_shape,
+ array_ops.expand_dims(filters, 0)], 0),
+ minval=0,
+ maxval=2,
+ dtype=dtypes.int32,
+ seed=distribution_util.gen_new_seed(
+ layer.seed, salt="conv_flipout"))
+ sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype)
+ for _ in range(rank):
+ sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C)
+ sign_output = array_ops.expand_dims(sign_output, 1)
+
+ sign_input = array_ops.tile( # tile for element-wise op broadcasting
+ sign_input,
+ [1] + [input_shape[i + 1] for i in range(rank)] + [1])
+ sign_output = array_ops.tile(
+ sign_output,
+ [1] + [output_shape[i + 1] for i in range(rank)] + [1])
+
+ perturbed_inputs = convolution_op(
+ inputs * sign_input, expected_kernel_posterior_affine_tensor)
+ perturbed_inputs *= sign_output
+
+ expected_outputs += perturbed_inputs
+ expected_outputs = nn.bias_add(expected_outputs,
+ bias_posterior.result_sample,
+ data_format="NHWC")
+
+ [
+ expected_outputs_, actual_outputs_,
+ expected_kernel_divergence_, actual_kernel_divergence_,
+ expected_bias_, actual_bias_,
+ expected_bias_divergence_, actual_bias_divergence_,
+ ] = sess.run([
+ expected_outputs, outputs,
+ kernel_divergence.result, kl_penalty[0],
+ bias_posterior.result_sample, layer.bias_posterior_tensor,
+ bias_divergence.result, kl_penalty[1],
+ ])
+
+ self.assertAllClose(
+ expected_bias_, actual_bias_,
+ rtol=1e-6, atol=0.)
+ self.assertAllClose(
+ expected_outputs_, actual_outputs_,
+ rtol=1e-6, atol=0.)
+ self.assertAllClose(
+ expected_kernel_divergence_, actual_kernel_divergence_,
+ rtol=1e-6, atol=0.)
+ self.assertAllClose(
+ expected_bias_divergence_, actual_bias_divergence_,
+ rtol=1e-6, atol=0.)
+
+ self.assertAllEqual(
+ [[kernel_posterior.distribution, kernel_prior.distribution, None]],
+ kernel_divergence.args)
+
+ self.assertAllEqual(
+ [[bias_posterior.distribution,
+ bias_prior.distribution,
+ bias_posterior.result_sample]],
+ bias_divergence.args)
+
+ def _testRandomConvFlipout(self, layer_class):
+ batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
+ with self.test_session() as sess:
+ seed = Counter()
+ if layer_class in (prob_layers_lib.Conv1DReparameterization,
+ prob_layers_lib.Conv1DFlipout):
+ inputs = random_ops.random_uniform(
+ [batch_size, width, channels], seed=seed())
+ kernel_size = (2,)
+ elif layer_class in (prob_layers_lib.Conv2DReparameterization,
+ prob_layers_lib.Conv2DFlipout):
+ inputs = random_ops.random_uniform(
+ [batch_size, height, width, channels], seed=seed())
+ kernel_size = (2, 2)
+ elif layer_class in (prob_layers_lib.Conv3DReparameterization,
+ prob_layers_lib.Conv3DFlipout):
+ inputs = random_ops.random_uniform(
+ [batch_size, depth, height, width, channels], seed=seed())
+ kernel_size = (2, 2, 2)
+
+ kernel_shape = kernel_size + (channels, filters)
+ bias_size = (filters,)
+
+ kernel_posterior = MockDistribution(
+ loc=random_ops.random_uniform(
+ kernel_shape, seed=seed()),
+ scale=random_ops.random_uniform(
+ kernel_shape, seed=seed()),
+ result_log_prob=random_ops.random_uniform(
+ kernel_shape, seed=seed()),
+ result_sample=random_ops.random_uniform(
+ kernel_shape, seed=seed()))
+ bias_posterior = MockDistribution(
+ loc=random_ops.random_uniform(
+ bias_size, seed=seed()),
+ scale=random_ops.random_uniform(
+ bias_size, seed=seed()),
+ result_log_prob=random_ops.random_uniform(
+ bias_size, seed=seed()),
+ result_sample=random_ops.random_uniform(
+ bias_size, seed=seed()))
+ layer_one = layer_class(
+ filters=filters,
+ kernel_size=kernel_size,
+ padding="SAME",
+ kernel_posterior_fn=lambda *args: kernel_posterior,
+ kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
+ bias_posterior_fn=lambda *args: bias_posterior,
+ bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
+ seed=44)
+ layer_two = layer_class(
+ filters=filters,
+ kernel_size=kernel_size,
+ padding="SAME",
+ kernel_posterior_fn=lambda *args: kernel_posterior,
+ kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
+ bias_posterior_fn=lambda *args: bias_posterior,
+ bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
+ seed=45)
+
+ outputs_one = layer_one(inputs)
+ outputs_two = layer_two(inputs)
+
+ outputs_one_, outputs_two_ = sess.run([
+ outputs_one, outputs_two])
+
+ self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)),
+ np.prod(outputs_one_.shape))
+
+ def testKLPenaltyKernelConv1DReparameterization(self):
+ self._testKLPenaltyKernel(prob_layers_lib.Conv1DReparameterization)
+
+ def testKLPenaltyKernelConv2DReparameterization(self):
+ self._testKLPenaltyKernel(prob_layers_lib.Conv2DReparameterization)
+
+ def testKLPenaltyKernelConv3DReparameterization(self):
+ self._testKLPenaltyKernel(prob_layers_lib.Conv3DReparameterization)
+
+ def testKLPenaltyKernelConv1DFlipout(self):
+ self._testKLPenaltyKernel(prob_layers_lib.Conv1DFlipout)
+
+ def testKLPenaltyKernelConv2DFlipout(self):
+ self._testKLPenaltyKernel(prob_layers_lib.Conv2DFlipout)
+
+ def testKLPenaltyKernelConv3DFlipout(self):
+ self._testKLPenaltyKernel(prob_layers_lib.Conv3DFlipout)
+
+ def testKLPenaltyBothConv1DReparameterization(self):
+ self._testKLPenaltyBoth(prob_layers_lib.Conv1DReparameterization)
+
+ def testKLPenaltyBothConv2DReparameterization(self):
+ self._testKLPenaltyBoth(prob_layers_lib.Conv2DReparameterization)
+
+ def testKLPenaltyBothConv3DReparameterization(self):
+ self._testKLPenaltyBoth(prob_layers_lib.Conv3DReparameterization)
+
+ def testKLPenaltyBothConv1DFlipout(self):
+ self._testKLPenaltyBoth(prob_layers_lib.Conv1DFlipout)
+
+ def testKLPenaltyBothConv2DFlipout(self):
+ self._testKLPenaltyBoth(prob_layers_lib.Conv2DFlipout)
- def testKLPenaltyKernelConv2DVariational(self):
- self._testKLPenaltyKernel(prob_layers_lib.Conv2DVariational)
+ def testKLPenaltyBothConv3DFlipout(self):
+ self._testKLPenaltyBoth(prob_layers_lib.Conv3DFlipout)
- def testKLPenaltyKernelConv3DVariational(self):
- self._testKLPenaltyKernel(prob_layers_lib.Conv3DVariational)
+ def testConv1DReparameterization(self):
+ self._testConvReparameterization(prob_layers_lib.Conv1DReparameterization)
- def testKLPenaltyBothConv1DVariational(self):
- self._testKLPenaltyBoth(prob_layers_lib.Conv1DVariational)
+ def testConv2DReparameterization(self):
+ self._testConvReparameterization(prob_layers_lib.Conv2DReparameterization)
- def testKLPenaltyBothConv2DVariational(self):
- self._testKLPenaltyBoth(prob_layers_lib.Conv2DVariational)
+ def testConv3DReparameterization(self):
+ self._testConvReparameterization(prob_layers_lib.Conv3DReparameterization)
- def testKLPenaltyBothConv3DVariational(self):
- self._testKLPenaltyBoth(prob_layers_lib.Conv3DVariational)
+ def testConv1DFlipout(self):
+ self._testConvFlipout(prob_layers_lib.Conv1DFlipout)
- def testConv1DVariational(self):
- self._testConvVariational(prob_layers_lib.Conv1DVariational)
+ def testConv2DFlipout(self):
+ self._testConvFlipout(prob_layers_lib.Conv2DFlipout)
- def testConv2DVariational(self):
- self._testConvVariational(prob_layers_lib.Conv2DVariational)
+ def testConv3DFlipout(self):
+ self._testConvFlipout(prob_layers_lib.Conv3DFlipout)
- def testConv3DVariational(self):
- self._testConvVariational(prob_layers_lib.Conv3DVariational)
+ def testRandomConv1DFlipout(self):
+ self._testRandomConvFlipout(prob_layers_lib.Conv1DFlipout)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
index 4e9f119351..342f38ccec 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib
+from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational as prob_layers_lib
from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py
index 93412afae7..a742b7c1aa 100644
--- a/tensorflow/contrib/bayesflow/python/ops/layers.py
+++ b/tensorflow/contrib/bayesflow/python/ops/layers.py
@@ -24,24 +24,36 @@ from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.bayesflow.python.ops.layers_conv_variational import *
-from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import *
+from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational import *
from tensorflow.contrib.bayesflow.python.ops.layers_util import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
- 'Convolution1DVariational',
- 'Convolution2DVariational',
- 'Convolution3DVariational',
- 'Conv1DVariational',
- 'Conv2DVariational',
- 'Conv3DVariational',
- 'convolution1d_variational',
- 'convolution2d_variational',
- 'convolution3d_variational',
- 'conv1d_variational',
- 'conv2d_variational',
- 'conv3d_variational',
+ 'Convolution1DReparameterization',
+ 'Convolution2DReparameterization',
+ 'Convolution3DReparameterization',
+ 'Convolution1DFlipout',
+ 'Convolution2DFlipout',
+ 'Convolution3DFlipout',
+ 'Conv1DReparameterization',
+ 'Conv2DReparameterization',
+ 'Conv3DReparameterization',
+ 'Conv1DFlipout',
+ 'Conv2DFlipout',
+ 'Conv3DFlipout',
+ 'convolution1d_reparameterization',
+ 'convolution2d_reparameterization',
+ 'convolution3d_reparameterization',
+ 'convolution1d_flipout',
+ 'convolution2d_flipout',
+ 'convolution3d_flipout',
+ 'conv1d_reparameterization',
+ 'conv2d_reparameterization',
+ 'conv3d_reparameterization',
+ 'conv1d_flipout',
+ 'conv2d_flipout',
+ 'conv3d_flipout',
'DenseReparameterization',
'DenseLocalReparameterization',
'DenseFlipout',
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py
index 6ffb55feb1..7723cfb442 100644
--- a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py
+++ b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops.distributions import kullback_leibler as kl_lib
from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.distributions import util as distribution_util
class _ConvVariational(layers_lib.Layer):
@@ -123,8 +124,6 @@ class _ConvVariational(layers_lib.Layer):
dilation_rate: Dilation rate for an atrous convolution.
activation: Activation function (`callable`).
activity_regularizer: Regularizer function for the output.
- kernel_use_local_reparameterization: Python `bool` indicating whether
- `kernel` calculation should employ the Local Reparameterization Trick.
kernel_posterior_fn: `callable` returning posterior.
kernel_posterior_tensor_fn: `callable` operating on posterior.
kernel_prior_fn: `callable` returning prior.
@@ -271,12 +270,6 @@ class _ConvVariational(layers_lib.Layer):
self._built_bias_divergence = True
return outputs
- def _apply_variational_kernel(self, inputs):
- self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn(
- self.kernel_posterior)
- outputs = self._convolution_op(inputs, self.kernel_posterior_tensor)
- return outputs
-
def _apply_variational_bias(self, inputs):
if self.bias_posterior is None:
self.bias_posterior_tensor = None
@@ -356,7 +349,165 @@ class _ConvVariational(layers_lib.Layer):
new_space)
-class Conv1DVariational(_ConvVariational):
+class _ConvReparameterization(_ConvVariational):
+ """Abstract nD convolution layer (private, used as implementation base).
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. It may also include a bias addition and activation function
+ on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+ distributions.
+
+ By default, the layer implements a stochastic forward pass via
+ sampling from the kernel and bias posteriors,
+ ```none
+ outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+ ```
+ where f denotes the layer's calculation. It uses the reparameterization
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`.
+
+ The arguments permit separate specification of the surrogate posterior
+ (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+ distributions.
+
+ Arguments:
+ rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of n integers, specifying the
+ length of the convolution window.
+ strides: An integer or tuple/list of n integers,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+ dilation_rate: An integer or tuple/list of n integers, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ name: A string, the name of the layer.
+
+ Properties:
+ rank: Python integer, dimensionality of convolution.
+ filters: Python integer, dimensionality of the output space.
+ kernel_size: Size of the convolution window.
+ strides: Stride length of convolution.
+ padding: Python string describing padding approach.
+ data_format: Python string describing input data's dimensions.
+ dilation_rate: Dilation rate for an atrous convolution.
+ activation: Activation function (`callable`).
+ activity_regularizer: Regularizer function for the output.
+ kernel_posterior_fn: `callable` returning posterior.
+ kernel_posterior_tensor_fn: `callable` operating on posterior.
+ kernel_prior_fn: `callable` returning prior.
+ kernel_divergence_fn: `callable` returning divergence.
+ bias_posterior_fn: `callable` returning posterior.
+ bias_posterior_tensor_fn: `callable` operating on posterior.
+ bias_prior_fn: `callable` returning prior.
+ bias_divergence_fn: `callable` returning divergence.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
+ """
+
+ def __init__(
+ self,
+ rank,
+ filters,
+ kernel_size,
+ strides=1,
+ padding="valid",
+ data_format="channels_last",
+ dilation_rate=1,
+ activation=None,
+ activity_regularizer=None,
+ trainable=True,
+ kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+ kernel_posterior_tensor_fn=lambda d: d.sample(),
+ kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+ kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long
+ bias_posterior_tensor_fn=lambda d: d.sample(),
+ bias_prior_fn=None,
+ bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ name=None,
+ **kwargs):
+ super(_ConvReparameterization, self).__init__(
+ rank=rank,
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ kernel_posterior_fn=kernel_posterior_fn,
+ kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+ kernel_prior_fn=kernel_prior_fn,
+ kernel_divergence_fn=kernel_divergence_fn,
+ bias_posterior_fn=bias_posterior_fn,
+ bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+ bias_prior_fn=bias_prior_fn,
+ bias_divergence_fn=bias_divergence_fn,
+ name=name, **kwargs)
+
+ def _apply_variational_kernel(self, inputs):
+ self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn(
+ self.kernel_posterior)
+ self.kernel_posterior_affine = None
+ self.kernel_posterior_affine_tensor = None
+ outputs = self._convolution_op(inputs, self.kernel_posterior_tensor)
+ return outputs
+
+
+class Conv1DReparameterization(_ConvReparameterization):
"""1D convolution layer (e.g. temporal convolution).
This layer creates a convolution kernel that is convolved
@@ -370,7 +521,9 @@ class Conv1DVariational(_ConvVariational):
```none
outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
```
- where f denotes the layer's calculation.
+ where f denotes the layer's calculation. It uses the reparameterization
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`.
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
@@ -442,8 +595,6 @@ class Conv1DVariational(_ConvVariational):
dilation_rate: Dilation rate for an atrous convolution.
activation: Activation function (`callable`).
activity_regularizer: Regularizer function for the output.
- kernel_use_local_reparameterization: Python `bool` indicating whether
- `kernel` calculation should employ the Local Reparameterization Trick.
kernel_posterior_fn: `callable` returning posterior.
kernel_posterior_tensor_fn: `callable` operating on posterior.
kernel_prior_fn: `callable` returning prior.
@@ -463,12 +614,12 @@ class Conv1DVariational(_ConvVariational):
tfp = tf.contrib.bayesflow
net = tf.reshape(features, [-1, 128, 1])
- net = tfp.layers.Conv1DVariational(64,
- kernel_size=5,
- padding="SAME",
- activation=tf.nn.relu)(net)
+ net = tfp.layers.Conv1DReparameterization(64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)(net)
net = tf.reshape(net, [-1, 128 * 64])
- logits = tfp.layers.DenseVariational(10)(net)
+ logits = tfp.layers.DenseReparameterization(10)(net)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
@@ -482,6 +633,10 @@ class Conv1DVariational(_ConvVariational):
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
"""
def __init__(
@@ -506,7 +661,7 @@ class Conv1DVariational(_ConvVariational):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
- super(Conv1DVariational, self).__init__(
+ super(Conv1DReparameterization, self).__init__(
rank=1,
filters=filters,
kernel_size=kernel_size,
@@ -528,7 +683,7 @@ class Conv1DVariational(_ConvVariational):
name=name, **kwargs)
-def conv1d_variational(
+def conv1d_reparameterization(
inputs,
filters,
kernel_size,
@@ -563,7 +718,9 @@ def conv1d_variational(
```none
outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
```
- where f denotes the layer's calculation.
+ where f denotes the layer's calculation. It uses the reparameterization
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`.
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
@@ -645,13 +802,13 @@ def conv1d_variational(
tfp = tf.contrib.bayesflow
net = tf.reshape(features, [-1, 128, 1])
- net = tfp.layers.conv1d_variational(net,
- 64,
- kernel_size=5,
- padding="SAME",
- activation=tf.nn.relu)
+ net = tfp.layers.conv1d_reparameterization(net,
+ filters=64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)
net = tf.reshape(net, [-1, 128 * 64])
- logits = tfp.layers.dense_variational(net, 10)
+ logits = tfp.layers.dense_reparameterization(net, 10)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
@@ -665,8 +822,12 @@ def conv1d_variational(
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
"""
- layer = Conv1DVariational(
+ layer = Conv1DReparameterization(
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -691,7 +852,7 @@ def conv1d_variational(
return layer.apply(inputs)
-class Conv2DVariational(_ConvVariational):
+class Conv2DReparameterization(_ConvReparameterization):
"""2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
@@ -705,7 +866,9 @@ class Conv2DVariational(_ConvVariational):
```none
outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
```
- where f denotes the layer's calculation.
+ where f denotes the layer's calculation. It uses the reparameterization
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`.
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
@@ -784,8 +947,6 @@ class Conv2DVariational(_ConvVariational):
dilation_rate: Dilation rate for an atrous convolution.
activation: Activation function (`callable`).
activity_regularizer: Regularizer function for the output.
- kernel_use_local_reparameterization: Python `bool` indicating whether
- `kernel` calculation should employ the Local Reparameterization Trick.
kernel_posterior_fn: `callable` returning posterior.
kernel_posterior_tensor_fn: `callable` operating on posterior.
kernel_prior_fn: `callable` returning prior.
@@ -805,15 +966,15 @@ class Conv2DVariational(_ConvVariational):
tfp = tf.contrib.bayesflow
net = tf.reshape(features, [-1, 32, 32, 3])
- net = tfp.layers.Conv2DVariational(64,
- kernel_size=5,
- padding="SAME",
- activation=tf.nn.relu)(net)
+ net = tfp.layers.Conv2DReparameterization(64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)(net)
net = tf.layers.MaxPooling2D(pool_size=2,
strides=2,
padding="SAME")(net)
net = tf.reshape(net, [-1, 8 * 8 * 64])
- logits = tfp.layers.DenseVariational(10)(net)
+ logits = tfp.layers.DenseReparameterization(10)(net)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
@@ -827,6 +988,10 @@ class Conv2DVariational(_ConvVariational):
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
"""
def __init__(
@@ -851,7 +1016,7 @@ class Conv2DVariational(_ConvVariational):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
- super(Conv2DVariational, self).__init__(
+ super(Conv2DReparameterization, self).__init__(
rank=2,
filters=filters,
kernel_size=kernel_size,
@@ -873,7 +1038,7 @@ class Conv2DVariational(_ConvVariational):
name=name, **kwargs)
-def conv2d_variational(
+def conv2d_reparameterization(
inputs,
filters,
kernel_size,
@@ -908,7 +1073,9 @@ def conv2d_variational(
```none
outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
```
- where f denotes the layer's calculation.
+ where f denotes the layer's calculation. It uses the reparameterization
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`.
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
@@ -997,17 +1164,17 @@ def conv2d_variational(
tfp = tf.contrib.bayesflow
net = tf.reshape(features, [-1, 32, 32, 3])
- net = tfp.layers.conv2d_variational(net,
- 64,
- kernel_size=5,
- padding="SAME",
- activation=tf.nn.relu)
+ net = tfp.layers.conv2d_reparameterization(net,
+ filters=64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)
net = tf.layers.max_pooling2d(net,
pool_size=2,
strides=2,
padding="SAME")
net = tf.reshape(net, [-1, 8 * 8 * 64])
- logits = tfp.layers.dense_variational(net, 10)
+ logits = tfp.layers.dense_reparameterization(net, 10)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
@@ -1021,8 +1188,12 @@ def conv2d_variational(
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
"""
- layer = Conv2DVariational(
+ layer = Conv2DReparameterization(
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -1047,7 +1218,7 @@ def conv2d_variational(
return layer.apply(inputs)
-class Conv3DVariational(_ConvVariational):
+class Conv3DReparameterization(_ConvReparameterization):
"""3D convolution layer (e.g. spatial convolution over volumes).
This layer creates a convolution kernel that is convolved
@@ -1061,7 +1232,9 @@ class Conv3DVariational(_ConvVariational):
```none
outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
```
- where f denotes the layer's calculation.
+ where f denotes the layer's calculation. It uses the reparameterization
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`.
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
@@ -1141,8 +1314,6 @@ class Conv3DVariational(_ConvVariational):
dilation_rate: Dilation rate for an atrous convolution.
activation: Activation function (`callable`).
activity_regularizer: Regularizer function for the output.
- kernel_use_local_reparameterization: Python `bool` indicating whether
- `kernel` calculation should employ the Local Reparameterization Trick.
kernel_posterior_fn: `callable` returning posterior.
kernel_posterior_tensor_fn: `callable` operating on posterior.
kernel_prior_fn: `callable` returning prior.
@@ -1162,15 +1333,15 @@ class Conv3DVariational(_ConvVariational):
tfp = tf.contrib.bayesflow
net = tf.reshape(features, [-1, 256, 32, 32, 3])
- net = tfp.layers.Conv3DVariational(64,
- kernel_size=5,
- padding="SAME",
- activation=tf.nn.relu)(net)
+ net = tfp.layers.Conv3DReparameterization(64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)(net)
net = tf.layers.MaxPooling2D(pool_size=2,
strides=2,
padding="SAME")(net)
net = tf.reshape(net, [-1, 256 * 8 * 8 * 64])
- logits = tfp.layers.DenseVariational(10)(net)
+ logits = tfp.layers.DenseReparameterization(10)(net)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
@@ -1184,6 +1355,10 @@ class Conv3DVariational(_ConvVariational):
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
"""
def __init__(
@@ -1208,7 +1383,7 @@ class Conv3DVariational(_ConvVariational):
bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
name=None,
**kwargs):
- super(Conv3DVariational, self).__init__(
+ super(Conv3DReparameterization, self).__init__(
rank=3,
filters=filters,
kernel_size=kernel_size,
@@ -1230,7 +1405,7 @@ class Conv3DVariational(_ConvVariational):
name=name, **kwargs)
-def conv3d_variational(
+def conv3d_reparameterization(
inputs,
filters,
kernel_size,
@@ -1265,7 +1440,9 @@ def conv3d_variational(
```none
outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
```
- where f denotes the layer's calculation.
+ where f denotes the layer's calculation. It uses the reparameterization
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`.
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
@@ -1355,17 +1532,17 @@ def conv3d_variational(
tfp = tf.contrib.bayesflow
net = tf.reshape(features, [-1, 256, 32, 32, 3])
- net = tfp.layers.conv3d_variational(net,
- 64,
- kernel_size=5,
- padding="SAME",
- activation=tf.nn.relu)
+ net = tfp.layers.conv3d_reparameterization(net,
+ filters=64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)
net = tf.layers.max_pooling2d(net,
pool_size=2,
strides=2,
padding="SAME")
net = tf.reshape(net, [-1, 256 * 8 * 8 * 64])
- logits = tfp.layers.dense_variational(net, 10)
+ logits = tfp.layers.dense_reparameterization(net, 10)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
@@ -1379,8 +1556,1352 @@ def conv3d_variational(
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
+ """
+ layer = Conv3DReparameterization(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ kernel_posterior_fn=kernel_posterior_fn,
+ kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+ kernel_prior_fn=kernel_prior_fn,
+ kernel_divergence_fn=kernel_divergence_fn,
+ bias_posterior_fn=bias_posterior_fn,
+ bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+ bias_prior_fn=bias_prior_fn,
+ bias_divergence_fn=bias_divergence_fn,
+ name=name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=name,
+ _reuse=reuse)
+ return layer.apply(inputs)
+
+
+class _ConvFlipout(_ConvVariational):
+ """Abstract nD convolution layer (private, used as implementation base).
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. It may also include a bias addition and activation function
+ on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+ distributions.
+
+ By default, the layer implements a stochastic forward pass via
+ sampling from the kernel and bias posteriors,
+ ```none
+ outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+ ```
+ where f denotes the layer's calculation. It uses the Flipout
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`. Flipout uses
+ roughly twice as many floating point operations as the
+ reparameterization estimator but has the advantage of significantly
+ lower variance.
+
+ The arguments permit separate specification of the surrogate posterior
+ (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+ distributions.
+
+ Arguments:
+ rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of n integers, specifying the
+ length of the convolution window.
+ strides: An integer or tuple/list of n integers,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+ dilation_rate: An integer or tuple/list of n integers, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ seed: Python scalar `int` which initializes the random number
+ generator. Default value: `None` (i.e., use global seed).
+ name: A string, the name of the layer.
+
+ Properties:
+ rank: Python integer, dimensionality of convolution.
+ filters: Python integer, dimensionality of the output space.
+ kernel_size: Size of the convolution window.
+ strides: Stride length of convolution.
+ padding: Python string describing padding approach.
+ data_format: Python string describing input data's dimensions.
+ dilation_rate: Dilation rate for an atrous convolution.
+ activation: Activation function (`callable`).
+ activity_regularizer: Regularizer function for the output.
+ kernel_posterior_fn: `callable` returning posterior.
+ kernel_posterior_tensor_fn: `callable` operating on posterior.
+ kernel_prior_fn: `callable` returning prior.
+ kernel_divergence_fn: `callable` returning divergence.
+ bias_posterior_fn: `callable` returning posterior.
+ bias_posterior_tensor_fn: `callable` operating on posterior.
+ bias_prior_fn: `callable` returning prior.
+ bias_divergence_fn: `callable` returning divergence.
+ seed: Python integer, used to create random seeds.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
+ """
+
+ def __init__(
+ self,
+ rank,
+ filters,
+ kernel_size,
+ strides=1,
+ padding="valid",
+ data_format="channels_last",
+ dilation_rate=1,
+ activation=None,
+ activity_regularizer=None,
+ trainable=True,
+ kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+ kernel_posterior_tensor_fn=lambda d: d.sample(),
+ kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+ kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long
+ bias_posterior_tensor_fn=lambda d: d.sample(),
+ bias_prior_fn=None,
+ bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ seed=None,
+ name=None,
+ **kwargs):
+ super(_ConvFlipout, self).__init__(
+ rank=rank,
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ kernel_posterior_fn=kernel_posterior_fn,
+ kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+ kernel_prior_fn=kernel_prior_fn,
+ kernel_divergence_fn=kernel_divergence_fn,
+ bias_posterior_fn=bias_posterior_fn,
+ bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+ bias_prior_fn=bias_prior_fn,
+ bias_divergence_fn=bias_divergence_fn,
+ name=name, **kwargs)
+ self.seed = seed
+
+ def _apply_variational_kernel(self, inputs):
+ if (not isinstance(self.kernel_posterior, independent_lib.Independent) or
+ not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)):
+ raise TypeError(
+ "`{}` requires "
+ "`kernel_posterior_fn` produce an instance of "
+ "`tf.distributions.Independent(tf.distributions.Normal)` "
+ "(saw: \"{}\").".format(
+ type(self).__name__, self.kernel_posterior.name))
+ self.kernel_posterior_affine = normal_lib.Normal(
+ loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc),
+ scale=self.kernel_posterior.distribution.scale)
+ self.kernel_posterior_affine_tensor = (
+ self.kernel_posterior_tensor_fn(self.kernel_posterior_affine))
+ self.kernel_posterior_tensor = None
+
+ outputs = self._convolution_op(
+ inputs, self.kernel_posterior.distribution.loc)
+
+ input_shape = array_ops.shape(inputs)
+ output_shape = array_ops.shape(outputs)
+ batch_shape = array_ops.expand_dims(input_shape[0], 0)
+ channels = input_shape[-1]
+
+ sign_input = layers_util.random_sign(
+ array_ops.concat([batch_shape,
+ array_ops.expand_dims(channels, 0)], 0),
+ dtype=inputs.dtype,
+ seed=self.seed)
+ sign_output = layers_util.random_sign(
+ array_ops.concat([batch_shape,
+ array_ops.expand_dims(self.filters, 0)], 0),
+ dtype=inputs.dtype,
+ seed=distribution_util.gen_new_seed(
+ self.seed, salt="conv_flipout"))
+ for _ in range(self.rank):
+ sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C)
+ sign_output = array_ops.expand_dims(sign_output, 1)
+
+ sign_input = array_ops.tile( # tile for element-wise op broadcasting
+ sign_input,
+ [1] + [input_shape[i + 1] for i in range(self.rank)] + [1])
+ sign_output = array_ops.tile(
+ sign_output,
+ [1] + [output_shape[i + 1] for i in range(self.rank)] + [1])
+
+ perturbed_inputs = self._convolution_op(
+ inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output
+
+ outputs += perturbed_inputs
+ return outputs
+
+
+class Conv1DFlipout(_ConvFlipout):
+ """1D convolution layer (e.g. temporal convolution) with Flipout.
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. It may also include a bias addition and activation function
+ on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+ distributions.
+
+ By default, the layer implements a stochastic forward pass via
+ sampling from the kernel and bias posteriors,
+ ```none
+ outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+ ```
+ where f denotes the layer's calculation. It uses the Flipout
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`. Flipout uses
+ roughly twice as many floating point operations as the
+ reparameterization estimator but has the advantage of significantly
+ lower variance.
+
+ The arguments permit separate specification of the surrogate posterior
+ (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+ distributions.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of a single integer, specifying the
+ length of the 1D convolution window.
+ strides: An integer or tuple/list of a single integer,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, length, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, length)`.
+ dilation_rate: An integer or tuple/list of a single integer, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ seed: Python scalar `int` which initializes the random number
+ generator. Default value: `None` (i.e., use global seed).
+ name: A string, the name of the layer.
+
+ Properties:
+ filters: Python integer, dimensionality of the output space.
+ kernel_size: Size of the convolution window.
+ strides: Stride length of convolution.
+ padding: Python string describing padding approach.
+ data_format: Python string describing input data's dimensions.
+ dilation_rate: Dilation rate for an atrous convolution.
+ activation: Activation function (`callable`).
+ activity_regularizer: Regularizer function for the output.
+ kernel_posterior_fn: `callable` returning posterior.
+ kernel_posterior_tensor_fn: `callable` operating on posterior.
+ kernel_prior_fn: `callable` returning prior.
+ kernel_divergence_fn: `callable` returning divergence.
+ bias_posterior_fn: `callable` returning posterior.
+ bias_posterior_tensor_fn: `callable` operating on posterior.
+ bias_prior_fn: `callable` returning prior.
+ bias_divergence_fn: `callable` returning divergence.
+ seed: Python integer, used to create random seeds.
+
+ #### Examples
+
+ We illustrate a Bayesian neural network with [variational inference](
+ https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+ assuming a dataset of `features` and `labels`.
+
+ ```python
+ tfp = tf.contrib.bayesflow
+
+ net = tf.reshape(features, [-1, 128, 1])
+ net = tfp.layers.Conv1DFlipout(64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)(net)
+ net = tf.reshape(net, [-1, 128 * 64])
+ logits = tfp.layers.DenseFlipout(10)(net)
+ neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ loss = neg_log_likelihood + kl
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+ ```
+
+ It uses the Flipout gradient estimator to minimize the
+ Kullback-Leibler divergence up to a constant, also known as the
+ negative Evidence Lower Bound. It consists of the sum of two terms:
+ the expected negative log-likelihood, which we approximate via
+ Monte Carlo; and the KL divergence, which is added via regularizer
+ terms which are arguments to the layer.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
+ """
+
+ def __init__(
+ self,
+ filters,
+ kernel_size,
+ strides=1,
+ padding="valid",
+ data_format="channels_last",
+ dilation_rate=1,
+ activation=None,
+ activity_regularizer=None,
+ trainable=True,
+ kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+ kernel_posterior_tensor_fn=lambda d: d.sample(),
+ kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+ kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long
+ bias_posterior_tensor_fn=lambda d: d.sample(),
+ bias_prior_fn=None,
+ bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ seed=None,
+ name=None,
+ **kwargs):
+ super(Conv1DFlipout, self).__init__(
+ rank=1,
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ kernel_posterior_fn=kernel_posterior_fn,
+ kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+ kernel_prior_fn=kernel_prior_fn,
+ kernel_divergence_fn=kernel_divergence_fn,
+ bias_posterior_fn=bias_posterior_fn,
+ bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+ bias_prior_fn=bias_prior_fn,
+ bias_divergence_fn=bias_divergence_fn,
+ seed=seed,
+ name=name, **kwargs)
+
+
+def conv1d_flipout(
+ inputs,
+ filters,
+ kernel_size,
+ strides=1,
+ padding="valid",
+ data_format="channels_last",
+ dilation_rate=1,
+ activation=None,
+ activity_regularizer=None,
+ trainable=True,
+ kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+ kernel_posterior_tensor_fn=lambda d: d.sample(),
+ kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+ kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long
+ bias_posterior_tensor_fn=lambda d: d.sample(),
+ bias_prior_fn=None,
+ bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ seed=None,
+ name=None,
+ reuse=None):
+ """Functional interface for 1D convolution layer (e.g. temporal convolution).
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. It may also include a bias addition and activation function
+ on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+ distributions.
+
+ By default, the layer implements a stochastic forward pass via
+ sampling from the kernel and bias posteriors,
+ ```none
+ outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+ ```
+ where f denotes the layer's calculation. It uses the Flipout
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`. Flipout uses
+ roughly twice as many floating point operations as the
+ reparameterization estimator but has the advantage of significantly
+ lower variance.
+
+ The arguments permit separate specification of the surrogate posterior
+ (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+ distributions.
+
+ Arguments:
+ inputs: Tensor input.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of a single integer, specifying the
+ length of the 1D convolution window.
+ strides: An integer or tuple/list of a single integer,
+ specifying the stride length of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, length, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, length)`.
+ dilation_rate: An integer or tuple/list of a single integer, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ seed: Python scalar `int` which initializes the random number
+ generator. Default value: `None` (i.e., use global seed).
+ name: A string, the name of the layer.
+ reuse: Boolean, whether to reuse the weights of a previous layer
+ by the same name.
+
+ Returns:
+ Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
+
+ #### Examples
+
+ We illustrate a Bayesian neural network with [variational inference](
+ https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+ assuming a dataset of `features` and `labels`.
+
+ ```python
+ tfp = tf.contrib.bayesflow
+
+ net = tf.reshape(features, [-1, 128, 1])
+ net = tfp.layers.conv1d_flipout(net,
+ filters=64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)
+ net = tf.reshape(net, [-1, 128 * 64])
+ logits = tfp.layers.dense_flipout(net, 10)
+ neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ loss = neg_log_likelihood + kl
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+ ```
+
+ It uses the Flipout gradient estimator to minimize the
+ Kullback-Leibler divergence up to a constant, also known as the
+ negative Evidence Lower Bound. It consists of the sum of two terms:
+ the expected negative log-likelihood, which we approximate via
+ Monte Carlo; and the KL divergence, which is added via regularizer
+ terms which are arguments to the layer.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
+ """
+ layer = Conv1DFlipout(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ kernel_posterior_fn=kernel_posterior_fn,
+ kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+ kernel_prior_fn=kernel_prior_fn,
+ kernel_divergence_fn=kernel_divergence_fn,
+ bias_posterior_fn=bias_posterior_fn,
+ bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+ bias_prior_fn=bias_prior_fn,
+ bias_divergence_fn=bias_divergence_fn,
+ seed=seed,
+ name=name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=name,
+ _reuse=reuse)
+ return layer.apply(inputs)
+
+
+class Conv2DFlipout(_ConvFlipout):
+ """2D convolution layer (e.g. spatial convolution over images) with Flipout.
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. It may also include a bias addition and activation function
+ on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+ distributions.
+
+ By default, the layer implements a stochastic forward pass via
+ sampling from the kernel and bias posteriors,
+ ```none
+ outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+ ```
+ where f denotes the layer's calculation. It uses the Flipout
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`. Flipout uses
+ roughly twice as many floating point operations as the
+ reparameterization estimator but has the advantage of significantly
+ lower variance.
+
+ The arguments permit separate specification of the surrogate posterior
+ (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+ distributions.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 2 integers, specifying the
+ height and width of the 2D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 2 integers,
+ specifying the strides of the convolution along the height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
+
+ dilation_rate: An integer or tuple/list of 2 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ seed: Python scalar `int` which initializes the random number
+ generator. Default value: `None` (i.e., use global seed).
+ name: A string, the name of the layer.
+
+ Properties:
+ filters: Python integer, dimensionality of the output space.
+ kernel_size: Size of the convolution window.
+ strides: Stride length of convolution.
+ padding: Python string describing padding approach.
+ data_format: Python string describing input data's dimensions.
+ dilation_rate: Dilation rate for an atrous convolution.
+ activation: Activation function (`callable`).
+ activity_regularizer: Regularizer function for the output.
+ kernel_posterior_fn: `callable` returning posterior.
+ kernel_posterior_tensor_fn: `callable` operating on posterior.
+ kernel_prior_fn: `callable` returning prior.
+ kernel_divergence_fn: `callable` returning divergence.
+ bias_posterior_fn: `callable` returning posterior.
+ bias_posterior_tensor_fn: `callable` operating on posterior.
+ bias_prior_fn: `callable` returning prior.
+ bias_divergence_fn: `callable` returning divergence.
+ seed: Python integer, used to create random seeds.
+
+ #### Examples
+
+ We illustrate a Bayesian neural network with [variational inference](
+ https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+ assuming a dataset of `features` and `labels`.
+
+ ```python
+ tfp = tf.contrib.bayesflow
+
+ net = tf.reshape(features, [-1, 32, 32, 3])
+ net = tfp.layers.Conv2DFlipout(64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)(net)
+ net = tf.layers.MaxPooling2D(pool_size=2,
+ strides=2,
+ padding="SAME")(net)
+ net = tf.reshape(net, [-1, 8 * 8 * 64])
+ logits = tfp.layers.DenseFlipout(10)(net)
+ neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ loss = neg_log_likelihood + kl
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+ ```
+
+ It uses the Flipout gradient estimator to minimize the
+ Kullback-Leibler divergence up to a constant, also known as the
+ negative Evidence Lower Bound. It consists of the sum of two terms:
+ the expected negative log-likelihood, which we approximate via
+ Monte Carlo; and the KL divergence, which is added via regularizer
+ terms which are arguments to the layer.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
+ """
+
+ def __init__(
+ self,
+ filters,
+ kernel_size,
+ strides=(1, 1),
+ padding="valid",
+ data_format="channels_last",
+ dilation_rate=(1, 1),
+ activation=None,
+ activity_regularizer=None,
+ trainable=True,
+ kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+ kernel_posterior_tensor_fn=lambda d: d.sample(),
+ kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+ kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long
+ bias_posterior_tensor_fn=lambda d: d.sample(),
+ bias_prior_fn=None,
+ bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ seed=None,
+ name=None,
+ **kwargs):
+ super(Conv2DFlipout, self).__init__(
+ rank=2,
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ kernel_posterior_fn=kernel_posterior_fn,
+ kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+ kernel_prior_fn=kernel_prior_fn,
+ kernel_divergence_fn=kernel_divergence_fn,
+ bias_posterior_fn=bias_posterior_fn,
+ bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+ bias_prior_fn=bias_prior_fn,
+ bias_divergence_fn=bias_divergence_fn,
+ seed=seed,
+ name=name, **kwargs)
+
+
+def conv2d_flipout(
+ inputs,
+ filters,
+ kernel_size,
+ strides=(1, 1),
+ padding="valid",
+ data_format="channels_last",
+ dilation_rate=(1, 1),
+ activation=None,
+ activity_regularizer=None,
+ trainable=True,
+ kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+ kernel_posterior_tensor_fn=lambda d: d.sample(),
+ kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+ kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long
+ bias_posterior_tensor_fn=lambda d: d.sample(),
+ bias_prior_fn=None,
+ bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ seed=None,
+ name=None,
+ reuse=None):
+ """Functional interface for the 2D convolution layer.
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. It may also include a bias addition and activation function
+ on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+ distributions.
+
+ By default, the layer implements a stochastic forward pass via
+ sampling from the kernel and bias posteriors,
+ ```none
+ outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+ ```
+ where f denotes the layer's calculation. It uses the Flipout
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`. Flipout uses
+ roughly twice as many floating point operations as the
+ reparameterization estimator but has the advantage of significantly
+ lower variance.
+
+ The arguments permit separate specification of the surrogate posterior
+ (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+ distributions.
+
+ Arguments:
+ inputs: Tensor input.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 2 integers, specifying the
+ height and width of the 2D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 2 integers,
+ specifying the strides of the convolution along the height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
+
+ dilation_rate: An integer or tuple/list of 2 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ seed: Python scalar `int` which initializes the random number
+ generator. Default value: `None` (i.e., use global seed).
+ name: A string, the name of the layer.
+ reuse: Boolean, whether to reuse the weights of a previous layer
+ by the same name.
+
+ Returns:
+ Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
+
+ #### Examples
+
+ We illustrate a Bayesian neural network with [variational inference](
+ https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+ assuming a dataset of `features` and `labels`.
+
+ ```python
+ tfp = tf.contrib.bayesflow
+
+ net = tf.reshape(features, [-1, 32, 32, 3])
+ net = tfp.layers.conv2d_flipout(net,
+ filters=64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)
+ net = tf.layers.max_pooling2d(net,
+ pool_size=2,
+ strides=2,
+ padding="SAME")
+ net = tf.reshape(net, [-1, 8 * 8 * 64])
+ logits = tfp.layers.dense_flipout(net, 10)
+ neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ loss = neg_log_likelihood + kl
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+ ```
+
+ It uses the Flipout gradient estimator to minimize the
+ Kullback-Leibler divergence up to a constant, also known as the
+ negative Evidence Lower Bound. It consists of the sum of two terms:
+ the expected negative log-likelihood, which we approximate via
+ Monte Carlo; and the KL divergence, which is added via regularizer
+ terms which are arguments to the layer.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
+ """
+ layer = Conv2DFlipout(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ kernel_posterior_fn=kernel_posterior_fn,
+ kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+ kernel_prior_fn=kernel_prior_fn,
+ kernel_divergence_fn=kernel_divergence_fn,
+ bias_posterior_fn=bias_posterior_fn,
+ bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+ bias_prior_fn=bias_prior_fn,
+ bias_divergence_fn=bias_divergence_fn,
+ seed=seed,
+ name=name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=name,
+ _reuse=reuse)
+ return layer.apply(inputs)
+
+
+class Conv3DFlipout(_ConvFlipout):
+ """3D convolution layer (e.g. spatial convolution over volumes) with Flipout.
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. It may also include a bias addition and activation function
+ on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+ distributions.
+
+ By default, the layer implements a stochastic forward pass via
+ sampling from the kernel and bias posteriors,
+ ```none
+ outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+ ```
+ where f denotes the layer's calculation. It uses the Flipout
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`. Flipout uses
+ roughly twice as many floating point operations as the
+ reparameterization estimator but has the advantage of significantly
+ lower variance.
+
+ The arguments permit separate specification of the surrogate posterior
+ (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+ distributions.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 3 integers, specifying the
+ depth, height and width of the 3D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 3 integers,
+ specifying the strides of the convolution along the depth,
+ height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, depth, height, width, channels)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, channels, depth, height, width)`.
+ dilation_rate: An integer or tuple/list of 3 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ seed: Python scalar `int` which initializes the random number
+ generator. Default value: `None` (i.e., use global seed).
+ name: A string, the name of the layer.
+
+ Properties:
+ filters: Python integer, dimensionality of the output space.
+ kernel_size: Size of the convolution window.
+ strides: Stride length of convolution.
+ padding: Python string describing padding approach.
+ data_format: Python string describing input data's dimensions.
+ dilation_rate: Dilation rate for an atrous convolution.
+ activation: Activation function (`callable`).
+ activity_regularizer: Regularizer function for the output.
+ kernel_posterior_fn: `callable` returning posterior.
+ kernel_posterior_tensor_fn: `callable` operating on posterior.
+ kernel_prior_fn: `callable` returning prior.
+ kernel_divergence_fn: `callable` returning divergence.
+ bias_posterior_fn: `callable` returning posterior.
+ bias_posterior_tensor_fn: `callable` operating on posterior.
+ bias_prior_fn: `callable` returning prior.
+ bias_divergence_fn: `callable` returning divergence.
+ seed: Python integer, used to create random seeds.
+
+ #### Examples
+
+ We illustrate a Bayesian neural network with [variational inference](
+ https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+ assuming a dataset of `features` and `labels`.
+
+ ```python
+ tfp = tf.contrib.bayesflow
+
+ net = tf.reshape(features, [-1, 256, 32, 32, 3])
+ net = tfp.layers.Conv3DFlipout(64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)(net)
+ net = tf.layers.MaxPooling2D(pool_size=2,
+ strides=2,
+ padding="SAME")(net)
+ net = tf.reshape(net, [-1, 256 * 8 * 8 * 64])
+ logits = tfp.layers.DenseFlipout(10)(net)
+ neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ loss = neg_log_likelihood + kl
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+ ```
+
+ It uses the Flipout gradient estimator to minimize the
+ Kullback-Leibler divergence up to a constant, also known as the
+ negative Evidence Lower Bound. It consists of the sum of two terms:
+ the expected negative log-likelihood, which we approximate via
+ Monte Carlo; and the KL divergence, which is added via regularizer
+ terms which are arguments to the layer.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
+ """
+
+ def __init__(
+ self,
+ filters,
+ kernel_size,
+ strides=(1, 1, 1),
+ padding="valid",
+ data_format="channels_last",
+ dilation_rate=(1, 1, 1),
+ activation=None,
+ activity_regularizer=None,
+ trainable=True,
+ kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+ kernel_posterior_tensor_fn=lambda d: d.sample(),
+ kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+ kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long
+ bias_posterior_tensor_fn=lambda d: d.sample(),
+ bias_prior_fn=None,
+ bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ seed=None,
+ name=None,
+ **kwargs):
+ super(Conv3DFlipout, self).__init__(
+ rank=3,
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ kernel_posterior_fn=kernel_posterior_fn,
+ kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
+ kernel_prior_fn=kernel_prior_fn,
+ kernel_divergence_fn=kernel_divergence_fn,
+ bias_posterior_fn=bias_posterior_fn,
+ bias_posterior_tensor_fn=bias_posterior_tensor_fn,
+ bias_prior_fn=bias_prior_fn,
+ bias_divergence_fn=bias_divergence_fn,
+ seed=seed,
+ name=name, **kwargs)
+
+
+def conv3d_flipout(
+ inputs,
+ filters,
+ kernel_size,
+ strides=(1, 1, 1),
+ padding="valid",
+ data_format="channels_last",
+ dilation_rate=(1, 1, 1),
+ activation=None,
+ activity_regularizer=None,
+ trainable=True,
+ kernel_posterior_fn=layers_util.default_mean_field_normal_fn(),
+ kernel_posterior_tensor_fn=lambda d: d.sample(),
+ kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)),
+ kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long
+ bias_posterior_tensor_fn=lambda d: d.sample(),
+ bias_prior_fn=None,
+ bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
+ seed=None,
+ name=None,
+ reuse=None):
+ """Functional interface for the 3D convolution layer.
+
+ This layer creates a convolution kernel that is convolved
+ (actually cross-correlated) with the layer input to produce a tensor of
+ outputs. It may also include a bias addition and activation function
+ on the outputs. It assumes the `kernel` and/or `bias` are drawn from
+ distributions.
+
+ By default, the layer implements a stochastic forward pass via
+ sampling from the kernel and bias posteriors,
+ ```none
+ outputs = f(inputs; kernel, bias), kernel, bias ~ posterior
+ ```
+ where f denotes the layer's calculation. It uses the Flipout
+ estimator [1], which performs a Monte Carlo approximation of the
+ distribution integrating over the `kernel` and `bias`. Flipout uses
+ roughly twice as many floating point operations as the
+ reparameterization estimator but has the advantage of significantly
+ lower variance.
+
+ The arguments permit separate specification of the surrogate posterior
+ (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
+ distributions.
+
+ Arguments:
+ inputs: Tensor input.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 3 integers, specifying the
+ depth, height and width of the 3D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 3 integers,
+ specifying the strides of the convolution along the depth,
+ height and width.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, depth, height, width, channels)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, channels, depth, height, width)`.
+ dilation_rate: An integer or tuple/list of 3 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ activity_regularizer: Optional regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ kernel_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `kernel` parameter. Default value:
+ `default_mean_field_normal_fn()`.
+ kernel_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ kernel_prior_fn: Python `callable` which creates `tf.distributions`
+ instance. See `default_mean_field_normal_fn` docstring for required
+ parameter signature.
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`.
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ sample is a `Tensor`.
+ bias_posterior_fn: Python `callable` which creates
+ `tf.distributions.Distribution` instance representing the surrogate
+ posterior of the `bias` parameter. Default value:
+ `default_mean_field_normal_fn(is_singular=True)` (which creates an
+ instance of `tf.distributions.Deterministic`).
+ bias_posterior_tensor_fn: Python `callable` which takes a
+ `tf.distributions.Distribution` instance and returns a representative
+ value. Default value: `lambda d: d.sample()`.
+ bias_prior_fn: Python `callable` which creates `tf.distributions` instance.
+ See `default_mean_field_normal_fn` docstring for required parameter
+ signature. Default value: `None` (no prior, no variational inference)
+ bias_divergence_fn: Python `callable` which takes the surrogate posterior
+ distribution, prior distribution and random variate sample(s) from the
+ surrogate posterior and computes or approximates the KL divergence. The
+ distributions are `tf.distributions.Distribution`-like instances and the
+ seed: Python scalar `int` which initializes the random number
+ generator. Default value: `None` (i.e., use global seed).
+ name: A string, the name of the layer.
+ reuse: Boolean, whether to reuse the weights of a previous layer
+ by the same name.
+
+ Returns:
+ Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
+
+ #### Examples
+
+ We illustrate a Bayesian neural network with [variational inference](
+ https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
+ assuming a dataset of `features` and `labels`.
+
+ ```python
+ tfp = tf.contrib.bayesflow
+
+ net = tf.reshape(features, [-1, 256, 32, 32, 3])
+ net = tfp.layers.conv3d_flipout(net,
+ filters=64,
+ kernel_size=5,
+ padding="SAME",
+ activation=tf.nn.relu)
+ net = tf.layers.max_pooling2d(net,
+ pool_size=2,
+ strides=2,
+ padding="SAME")
+ net = tf.reshape(net, [-1, 256 * 8 * 8 * 64])
+ logits = tfp.layers.dense_flipout(net, 10)
+ neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits)
+ kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ loss = neg_log_likelihood + kl
+ train_op = tf.train.AdamOptimizer().minimize(loss)
+ ```
+
+ It uses the Flipout gradient estimator to minimize the
+ Kullback-Leibler divergence up to a constant, also known as the
+ negative Evidence Lower Bound. It consists of the sum of two terms:
+ the expected negative log-likelihood, which we approximate via
+ Monte Carlo; and the KL divergence, which is added via regularizer
+ terms which are arguments to the layer.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
"""
- layer = Conv3DVariational(
+ layer = Conv3DFlipout(
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -1398,6 +2919,7 @@ def conv3d_variational(
bias_posterior_tensor_fn=bias_posterior_tensor_fn,
bias_prior_fn=bias_prior_fn,
bias_divergence_fn=bias_divergence_fn,
+ seed=seed,
name=name,
dtype=inputs.dtype.base_dtype,
_scope=name,
@@ -1407,9 +2929,15 @@ def conv3d_variational(
# Aliases
-Convolution1DVariational = Conv1DVariational
-Convolution2DVariational = Conv2DVariational
-Convolution3DVariational = Conv3DVariational
-convolution1d_variational = conv1d_variational
-convolution2d_variational = conv2d_variational
-convolution3d_variational = conv3d_variational
+Convolution1DReparameterization = Conv1DReparameterization
+Convolution2DReparameterization = Conv2DReparameterization
+Convolution3DReparameterization = Conv3DReparameterization
+convolution1d_reparameterization = conv1d_reparameterization
+convolution2d_reparameterization = conv2d_reparameterization
+convolution3d_reparameterization = conv3d_reparameterization
+Convolution1DFlipout = Conv1DFlipout
+Convolution2DFlipout = Conv2DFlipout
+Convolution3DFlipout = Conv3DFlipout
+convolution1d_flipout = conv1d_flipout
+convolution2d_flipout = conv2d_flipout
+convolution3d_flipout = conv3d_flipout
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py
index a749a396f1..591a8e553d 100644
--- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py
@@ -13,13 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Dense Bayesian layer using KL-divergence based variational inference.
-
-@@DenseReparameterization
-@@DenseLocalReparameterization
-@@DenseFlipout
-@@dense_reparameterization
-@@dense_local_reparameterization
-@@dense_flipout
"""
from __future__ import absolute_import
@@ -33,25 +26,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_lib
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
-from tensorflow.python.ops import random_ops
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops.distributions import kullback_leibler as kl_lib
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.ops.distributions import util as distribution_util
-__all__ = [
- "DenseReparameterization",
- "DenseLocalReparameterization",
- "DenseFlipout",
- "dense_reparameterization",
- "dense_local_reparameterization",
- "dense_flipout",
-]
-
-
class _DenseVariational(layers_lib.Layer):
"""Abstract densely-connected class (private, used as implementation base).
@@ -285,6 +266,10 @@ class DenseReparameterization(_DenseVariational):
outputs = activation(matmul(inputs, kernel) + bias)
```
+ It uses the reparameterization estimator [1], which performs a Monte Carlo
+ approximation of the distribution integrating over the `kernel` and
+ `bias`.
+
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
@@ -372,6 +357,10 @@ class DenseReparameterization(_DenseVariational):
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
"""
def __init__(
@@ -445,6 +434,10 @@ def dense_reparameterization(
outputs = activation(matmul(inputs, kernel) + bias)
```
+ It uses the reparameterization estimator [1], which performs a Monte Carlo
+ approximation of the distribution integrating over the `kernel` and
+ `bias`.
+
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
@@ -524,6 +517,10 @@ def dense_reparameterization(
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Auto-Encoding Variational Bayes."
+ Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014.
"""
layer = DenseReparameterization(
units,
@@ -558,6 +555,10 @@ class DenseLocalReparameterization(_DenseVariational):
outputs = activation(matmul(inputs, kernel) + bias)
```
+ It uses the local reparameterization estimator [1], which performs a
+ Monte Carlo approximation of the distribution on the hidden units
+ induced by the `kernel` and `bias`.
+
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
@@ -645,6 +646,10 @@ class DenseLocalReparameterization(_DenseVariational):
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Variational Dropout and the Local Reparameterization Trick."
+ Diederik P. Kingma, Tim Salimans, Max Welling.
+ Neural Information Processing Systems, 2015.
"""
def __init__(
@@ -688,7 +693,7 @@ class DenseLocalReparameterization(_DenseVariational):
"`DenseLocalReparameterization` requires "
"`kernel_posterior_fn` produce an instance of "
"`tf.distributions.Independent(tf.distributions.Normal)` "
- "(saw: \"{}\").".format(type(self.kernel_posterior).__name__))
+ "(saw: \"{}\").".format(self.kernel_posterior.name))
self.kernel_posterior_affine = normal_lib.Normal(
loc=self._matmul(inputs, self.kernel_posterior.distribution.loc),
scale=standard_ops.sqrt(self._matmul(
@@ -730,6 +735,10 @@ def dense_local_reparameterization(
outputs = activation(matmul(inputs, kernel) + bias)
```
+ It uses the local reparameterization estimator [1], which performs a
+ Monte Carlo approximation of the distribution on the hidden units
+ induced by the `kernel` and `bias`.
+
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
@@ -809,6 +818,10 @@ def dense_local_reparameterization(
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Variational Dropout and the Local Reparameterization Trick."
+ Diederik P. Kingma, Tim Salimans, Max Welling.
+ Neural Information Processing Systems, 2015.
"""
layer = DenseLocalReparameterization(
units,
@@ -843,6 +856,12 @@ class DenseFlipout(_DenseVariational):
outputs = activation(matmul(inputs, kernel) + bias)
```
+ It uses the Flipout estimator [1], which performs a Monte Carlo
+ approximation of the distribution integrating over the `kernel` and
+ `bias`. Flipout uses roughly twice as many floating point operations
+ as the reparameterization estimator but has the advantage of
+ significantly lower variance.
+
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
@@ -933,6 +952,11 @@ class DenseFlipout(_DenseVariational):
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
"""
def __init__(
@@ -978,7 +1002,7 @@ class DenseFlipout(_DenseVariational):
"`DenseFlipout` requires "
"`kernel_posterior_fn` produce an instance of "
"`tf.distributions.Independent(tf.distributions.Normal)` "
- "(saw: \"{}\").".format(type(self.kernel_posterior).__name__))
+ "(saw: \"{}\").".format(self.kernel_posterior.name))
self.kernel_posterior_affine = normal_lib.Normal(
loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc),
scale=self.kernel_posterior.distribution.scale)
@@ -989,8 +1013,11 @@ class DenseFlipout(_DenseVariational):
input_shape = array_ops.shape(inputs)
batch_shape = input_shape[:-1]
- sign_input = random_sign(input_shape, dtype=inputs.dtype, seed=self.seed)
- sign_output = random_sign(
+ sign_input = layers_util.random_sign(
+ input_shape,
+ dtype=inputs.dtype,
+ seed=self.seed)
+ sign_output = layers_util.random_sign(
array_ops.concat([batch_shape,
array_ops.expand_dims(self.units, 0)], 0),
dtype=inputs.dtype,
@@ -1035,6 +1062,12 @@ def dense_flipout(
outputs = activation(matmul(inputs, kernel) + bias)
```
+ It uses the Flipout estimator [1], which performs a Monte Carlo
+ approximation of the distribution integrating over the `kernel` and
+ `bias`. Flipout uses roughly twice as many floating point operations
+ as the reparameterization estimator but has the advantage of
+ significantly lower variance.
+
The arguments permit separate specification of the surrogate posterior
(`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
distributions.
@@ -1116,6 +1149,11 @@ def dense_flipout(
the expected negative log-likelihood, which we approximate via
Monte Carlo; and the KL divergence, which is added via regularizer
terms which are arguments to the layer.
+
+ [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches."
+ Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb
"""
layer = DenseFlipout(
units,
@@ -1136,11 +1174,3 @@ def dense_flipout(
_scope=name,
_reuse=reuse)
return layer.apply(inputs)
-
-
-def random_sign(shape, dtype=dtypes.float32, seed=None):
- """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution."""
- random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2,
- dtype=dtypes.int32,
- seed=seed)
- return math_ops.cast(2 * random_bernoulli - 1, dtype)
diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_util.py b/tensorflow/contrib/bayesflow/python/ops/layers_util.py
index 9a4fecf4e5..8c1fb203f7 100644
--- a/tensorflow/contrib/bayesflow/python/ops/layers_util.py
+++ b/tensorflow/contrib/bayesflow/python/ops/layers_util.py
@@ -23,9 +23,12 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import normal as normal_lib
@@ -178,3 +181,11 @@ def default_mean_field_normal_fn(
return independent_lib.Independent(
dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
return _fn
+
+
+def random_sign(shape, dtype=dtypes.float32, seed=None):
+ """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution."""
+ random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2,
+ dtype=dtypes.int32,
+ seed=seed)
+ return math_ops.cast(2 * random_bernoulli - 1, dtype)
diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD
index 392ac7fa1c..6fdcd0f996 100644
--- a/tensorflow/contrib/boosted_trees/BUILD
+++ b/tensorflow/contrib/boosted_trees/BUILD
@@ -196,6 +196,7 @@ py_test(
name = "quantile_ops_test",
size = "small",
srcs = ["python/kernel_tests/quantile_ops_test.py"],
+ shard_count = 3,
srcs_version = "PY2AND3",
deps = [
":quantile_ops_py",
diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
index 8600c8c53c..88f3006407 100644
--- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
@@ -46,6 +46,7 @@ const char* const kHandleName = "handle";
const char* const kNextStampTokenName = "next_stamp_token";
const char* const kStampTokenName = "stamp_token";
const char* const kAreBucketsReadyName = "are_buckets_ready";
+const char* const kGenerateQuantiles = "generate_quantiles";
// Names for sparse arguments.
const char* const kNumSparseFeaturesName = "num_sparse_features";
const char* const kSparseBucketsName = "sparse_buckets";
@@ -182,6 +183,16 @@ std::vector<float> GenerateBoundaries(const QuantileStream& stream,
return boundaries;
}
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+ int num_quantiles) {
+ // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
+ // will be returned.
+ std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles);
+ CHECK_EQ(boundaries.size(), num_quantiles + 1);
+ return boundaries;
+}
+
// Copies quantiles to output list.
void CopyBoundaries(OpKernelContext* const context,
const std::vector<float>& boundaries, const int64 index,
@@ -224,6 +235,8 @@ class CreateQuantileAccumulatorOp : public OpKernel {
OP_REQUIRES_OK(context,
context->GetAttr(kNumQuantilesName, &num_quantiles_));
OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
}
void Compute(OpKernelContext* context) override {
@@ -231,9 +244,9 @@ class CreateQuantileAccumulatorOp : public OpKernel {
// other exceptions. If one already exists, it unrefs the new one.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
- auto result =
- new QuantileStreamResource(epsilon_, num_quantiles_, max_elements_,
- stamp_token_t->scalar<int64>()());
+ auto result = new QuantileStreamResource(epsilon_, num_quantiles_,
+ max_elements_, generate_quantiles_,
+ stamp_token_t->scalar<int64>()());
auto status = CreateResource(context, HandleFromInput(context, 0), result);
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
OP_REQUIRES(context, false, status);
@@ -246,6 +259,7 @@ class CreateQuantileAccumulatorOp : public OpKernel {
// An upperbound on the number of enteries that the summaries might have
// for a feature.
int64 max_elements_;
+ bool generate_quantiles_;
};
REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU),
@@ -597,10 +611,15 @@ class QuantileAccumulatorFlushOp : public OpKernel {
<< "Passed stamp token: " << stamp_token << " "
<< "Current token: " << streams_resource->stamp();
QuantileStream* stream = streams_resource->stream(stamp_token);
+ bool generate_quantiles = streams_resource->generate_quantiles();
stream->Finalize();
+
streams_resource->set_boundaries(
stamp_token,
- GenerateBoundaries(*stream, streams_resource->num_quantiles()));
+ generate_quantiles
+ ? GenerateQuantiles(*stream, streams_resource->num_quantiles())
+ : GenerateBoundaries(*stream, streams_resource->num_quantiles()));
+
streams_resource->Reset(next_stamp_token);
}
};
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h
index e1bef02788..3c54868951 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h
+++ b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_
#include "tensorflow/core/lib/hash/hash.h"
@@ -58,4 +58,4 @@ struct ClassPartitionKey {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h
index 3814edb567..ec4e7c52bb 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h
+++ b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_
#include <unordered_map>
#include <vector>
@@ -79,4 +79,4 @@ class FeatureStatsAccumulator {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h b/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h
index aed0d9fdac..37a7103704 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h
+++ b/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
@@ -50,4 +50,4 @@ class ExamplePartitioner {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h
index 339c2e0fde..382b85cf0b 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h
+++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h
@@ -13,8 +13,8 @@
// limitations under the License.
//
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
@@ -58,4 +58,4 @@ struct FeatureSplitCandidate {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h
index 34e3ddb777..3dd03215d8 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h
+++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
#include <math.h>
@@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h
index 642a183aec..cd925f6b65 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h
+++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/Eigenvalues"
@@ -298,4 +298,4 @@ struct NodeStats {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h
index 054ccd9a8c..81ee2774bd 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h
+++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
#include <string>
@@ -81,4 +81,4 @@ struct SplitStats {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
index ee29a8aa79..cc3dc226cd 100644
--- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
#include <vector>
@@ -45,4 +45,4 @@ class MultipleAdditiveTrees {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
index 70037d5bd8..804b218f1c 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
#include <algorithm>
#include <unordered_map>
@@ -129,4 +129,4 @@ constexpr decltype(CompareFn())
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h
index fd577ad712..1c4181f1b1 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
#include <memory>
#include <vector>
@@ -322,4 +322,4 @@ WeightedQuantilesStream<ValueType, WeightType, CompareFn>::GetQuantileSpecs(
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
index c329c6d4f7..aec232f3cb 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#include <cstring>
#include <vector>
@@ -334,4 +334,4 @@ constexpr decltype(CompareFn())
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h
index d95878ec87..b98190b10d 100644
--- a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/core/framework/tensor.h"
@@ -42,4 +42,4 @@ void RandomlyInitializeBatchFeatures(
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h
index 5e12429ba7..1838b4cee2 100644
--- a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h
+++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
#include <memory>
@@ -72,4 +72,4 @@ class RandomTreeGen {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
index 604ff02744..43526c229a 100644
--- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
#include "tensorflow/contrib/boosted_trees/lib/utils/example.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT
@@ -46,4 +46,4 @@ class DecisionTree {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h
index badc629a11..da5e744851 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h"
@@ -92,4 +92,4 @@ class BatchFeatures {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h
index c3f1c918ca..928bfbfe5c 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
#include <unordered_set>
#include <vector>
@@ -74,4 +74,4 @@ class DropoutUtils {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h
index 54f60e1dee..1371ff337f 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/example.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
#include <algorithm>
#include <unordered_set>
@@ -131,4 +131,4 @@ struct Example {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h
index 5b33c81588..1b654e1c44 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
#include <vector>
@@ -205,4 +205,4 @@ class ExamplesIterable {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/macros.h b/tensorflow/contrib/boosted_trees/lib/utils/macros.h
index 28ea0a4dc1..9a53fb2ef7 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/macros.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/macros.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
#include "tensorflow/core/platform/macros.h"
@@ -23,4 +23,4 @@
return (STATUS); \
}
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h b/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h
index c141fe059d..b2166f53d7 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
#include "tensorflow/core/platform/logging.h"
@@ -44,4 +44,4 @@ class OptionalValue {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h
index c80431b558..ec06787e1d 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
+#ifndef TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
+#define TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
#include "tensorflow/core/lib/core/threadpool.h"
@@ -30,4 +30,4 @@ void ParallelFor(int64 batch_size, int64 desired_parallelism,
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
+#endif // TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/random.h b/tensorflow/contrib/boosted_trees/lib/utils/random.h
index 6dd55fcacc..546d344f55 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/random.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/random.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
+#ifndef TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
+#define TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
#include "tensorflow/core/lib/random/simple_philox.h"
@@ -36,4 +36,4 @@ inline int32 PoissonBootstrap(random::SimplePhilox* rng) {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
+#endif // TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h
index 9664c9d1c6..87fb1fbf5a 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -127,4 +127,4 @@ class SparseColumnIterable {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h
index 58f5e5a0d1..475d3718ec 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h
+++ b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -57,4 +57,4 @@ class TensorUtils {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
index 1fa70bafdd..bb57dcf8ae 100644
--- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
@@ -39,6 +39,7 @@ REGISTER_OP("CreateQuantileAccumulator")
.Attr("max_elements: int = 1099511627776") // 1 << 40
.Attr("epsilon: float")
.Attr("num_quantiles: int")
+ .Attr("generate_quantiles: bool=False")
.Input("quantile_accumulator_handle: resource")
.Input("stamp_token: int64")
.SetShapeFn([](shape_inference::InferenceContext* c) {
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
index 888d5c57ed..eefa7ef0dc 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
@@ -106,9 +106,11 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
| 6 | 16 | [16, 17, 18, 19, 20, 21]
"""
+ num_quantiles = 3
with self.test_session() as sess:
accumulator = quantile_ops.QuantileAccumulator(
- init_stamp_token=0, num_quantiles=3, epsilon=0.001, name="q1")
+ init_stamp_token=0, num_quantiles=num_quantiles,
+ epsilon=0.001, name="q1")
resources.initialize_resources(resources.shared_resources()).run()
input_column = array_ops.placeholder(dtypes.float32)
weights = array_ops.placeholder(dtypes.float32)
@@ -131,8 +133,104 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
buckets, are_ready_flush = (sess.run(
[buckets, are_ready_flush]))
self.assertEqual(True, are_ready_flush)
+ self.assertEqual(num_quantiles + 1, len(buckets))
self.assertAllEqual([1, 86., 170., 253.], buckets)
+ def testStreamingQuantileBucketsLowPrecisionInput(self):
+ """Tests inputs that simulate low precision float16 values."""
+
+ num_quantiles = 3
+ # set generate_quantiles to True since the test will generate fewer
+ # boundaries otherwise.
+ with self.test_session() as sess:
+ accumulator = quantile_ops.QuantileAccumulator(
+ init_stamp_token=0, num_quantiles=num_quantiles,
+ epsilon=0.001, name="q1", generate_quantiles=True)
+ resources.initialize_resources(resources.shared_resources()).run()
+ input_column = array_ops.placeholder(dtypes.float32)
+ weights = array_ops.placeholder(dtypes.float32)
+ update = accumulator.add_summary(
+ stamp_token=0,
+ column=input_column,
+ example_weights=weights)
+
+ with self.test_session() as sess:
+ # This input is generated by integer in the range [2030, 2060]
+ # but represented by with float16 precision. Integers <= 2048 are
+ # exactly represented, whereas numbers > 2048 are rounded; and hence
+ # numbers > 2048 are repeated. For precision loss / rounding, see:
+ # https://en.wikipedia.org/wiki/Half-precision_floating-point_format.
+ #
+ # The intent of the test is not handling of float16 values, but to
+ # validate the number of buckets is returned, in cases where the input
+ # may contain repeated values.
+ inputs = [
+ 2030.0, 2031.0, 2032.0, 2033.0, 2034.0, 2035.0, 2036.0, 2037.0,
+ 2038.0, 2039.0, 2040.0, 2041.0, 2042.0, 2043.0, 2044.0, 2045.0,
+ 2046.0, 2047.0, 2048.0, 2048.0, 2050.0, 2052.0, 2052.0, 2052.0,
+ 2054.0, 2056.0, 2056.0, 2056.0, 2058.0, 2060.0
+ ]
+ sess.run(update,
+ {input_column: inputs,
+ weights: [1] * len(inputs)})
+
+ with self.test_session() as sess:
+ sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
+ are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
+ buckets, are_ready_flush = (sess.run(
+ [buckets, are_ready_flush]))
+ self.assertEqual(True, are_ready_flush)
+ self.assertEqual(num_quantiles + 1, len(buckets))
+ self.assertAllEqual([2030, 2040, 2050, 2060], buckets)
+
+ def _testStreamingQuantileBucketsHelper(self, inputs):
+ """Helper to test quantile buckets on different inputs."""
+
+ # Use 3 quantiles, 4 boundaries for simplicity.
+ num_quantiles = 3
+ # set generate_quantiles to True since the test will generate fewer
+ # boundaries otherwise.
+ with self.test_session() as sess:
+ accumulator = quantile_ops.QuantileAccumulator(
+ init_stamp_token=0, num_quantiles=num_quantiles,
+ epsilon=0.001, name="q1", generate_quantiles=True)
+ resources.initialize_resources(resources.shared_resources()).run()
+ input_column = array_ops.placeholder(dtypes.float32)
+ weights = array_ops.placeholder(dtypes.float32)
+ update = accumulator.add_summary(
+ stamp_token=0,
+ column=input_column,
+ example_weights=weights)
+
+ with self.test_session() as sess:
+ sess.run(update,
+ {input_column: inputs,
+ weights: [1] * len(inputs)})
+
+ with self.test_session() as sess:
+ sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
+ are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
+ buckets, are_ready_flush = (sess.run(
+ [buckets, are_ready_flush]))
+ self.assertEqual(True, are_ready_flush)
+ self.assertEqual(num_quantiles + 1, len(buckets))
+
+ def testStreamingQuantileBucketsRepeatedSingleValue(self):
+ inputs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ self._testStreamingQuantileBucketsHelper(inputs)
+
+ def testStreamingQ2antileBucketsRepeatedTwoValues(self):
+ inputs = [1, 1, 1, 2, 2, 2, 2, 2, 1, 1]
+ self._testStreamingQuantileBucketsHelper(inputs)
+
+ def testStreamingQ2antileBucketsRepeatedTwoValuesUnbalanced(self):
+ inputs = [7, 7, 7, 2, 7, 7, 2, 2, 7, 7]
+ self._testStreamingQuantileBucketsHelper(inputs)
+
+ def testStreamingQuantileBucketsFewerInputstThanBuckets(self):
+ inputs = [5]
+ self._testStreamingQuantileBucketsHelper(inputs)
+
def testStreamingQuantileBuckets(self):
"""Sets up the quantile summary op test as follows.
diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
index 23168bf493..b281a4c6d1 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
@@ -104,7 +104,7 @@ def run_handler_scheduled_ops(per_handler_ops, stamp, worker_device):
batched_ops = collections.defaultdict(list)
# Group the ops by their batching_key. Ops that share the same batching key
# can be executed together.
- for handler in per_handler_ops.keys():
+ for handler in sorted(per_handler_ops.keys()):
for op in per_handler_ops[handler]:
batched_ops[(op.batching_key(), op.batch_runner_fn())].append(op)
op_results = {}
diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
index 294e04002a..97d57e8b23 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
@@ -47,7 +47,8 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
num_quantiles,
max_elements=None,
name=None,
- container=None):
+ container=None,
+ generate_quantiles=False):
"""Creates a QuantileAccumulator object.
Args:
@@ -57,8 +58,11 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
max_elements: Maximum number of elements added to the accumulator.
name: the name to save the accumulator under.
container: An optional `string`. Defaults to `""`
+ generate_quantiles: Generate quantiles instead of approximate boundaries.
+ If true, exactly `num_quantiles` will be produced in the final summary.
"""
self._epsilon = epsilon
+ self._generate_quantiles = generate_quantiles
name = _PATTERN.sub("", name)
with ops.name_scope(name, "QuantileAccumulator") as name:
@@ -70,7 +74,8 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
init_stamp_token,
epsilon=epsilon,
max_elements=max_elements,
- num_quantiles=num_quantiles)
+ num_quantiles=num_quantiles,
+ generate_quantiles=generate_quantiles)
is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized(
self._quantile_accumulator_handle)
resources.register_resource(self._quantile_accumulator_handle,
@@ -176,7 +181,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
summaries=summary)
def flush(self, stamp_token, next_stamp_token):
- """Finalizes quantile summary stream and resets it for next iteration."""
+ """Finalizes quantile summary stream and resets it for next iteration.
+
+ Args:
+ stamp_token: Exepcted current token.
+ next_stamp_token: Next value for the token.
+ Returns:
+ A list of quantiles or approximate boundaries.
+ """
return gen_quantile_ops.quantile_accumulator_flush(
quantile_accumulator_handle=self._quantile_accumulator_handle,
stamp_token=stamp_token,
diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
index ad9c8961aa..3ebf28ea44 100644
--- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
+++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
@@ -179,4 +179,4 @@ class DecisionTreeEnsembleResource : public StampedResource {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_
diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h
index fb29f79e57..fdaaae7f47 100644
--- a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h
+++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h" // NOLINT
@@ -32,12 +32,14 @@ using QuantileStream =
class QuantileStreamResource : public StampedResource {
public:
QuantileStreamResource(const float epsilon, const int32 num_quantiles,
- const int64 max_elements, int64 stamp_token)
+ const int64 max_elements, bool generate_quantiles,
+ int64 stamp_token)
: stream_(epsilon, max_elements),
are_buckets_ready_(false),
epsilon_(epsilon),
num_quantiles_(num_quantiles),
- max_elements_(max_elements) {
+ max_elements_(max_elements),
+ generate_quantiles_(generate_quantiles) {
set_stamp(stamp_token);
}
@@ -74,6 +76,11 @@ class QuantileStreamResource : public StampedResource {
are_buckets_ready_ = are_buckets_ready;
}
+ bool generate_quantiles() const { return generate_quantiles_; }
+ void set_generate_quantiles(bool generate_quantiles) {
+ generate_quantiles_ = generate_quantiles;
+ }
+
private:
~QuantileStreamResource() override {}
@@ -95,10 +102,15 @@ class QuantileStreamResource : public StampedResource {
const int32 num_quantiles_;
// An upper-bound for the number of elements.
int64 max_elements_;
+
+ // Generate quantiles instead of approximate boundaries.
+ // If true, exactly `num_quantiles` will be produced in the final summary.
+ bool generate_quantiles_;
+
TF_DISALLOW_COPY_AND_ASSIGN(QuantileStreamResource);
};
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/contrib/boosted_trees/resources/stamped_resource.h b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h
index aabeeb9851..957bbe8d61 100644
--- a/tensorflow/contrib/boosted_trees/resources/stamped_resource.h
+++ b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
+#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
+#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/mutex.h"
@@ -39,4 +39,4 @@ class StampedResource : public ResourceBase {
} // namespace boosted_trees
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
+#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
index deb324634b..1bfd27305d 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
namespace tensorflow {
-
namespace {
constexpr size_t kBufferSize = 1024 * 1024; // In bytes.
@@ -40,33 +39,6 @@ Status ParseJson(StringPiece json, Json::Value* result) {
return Status::OK();
}
-string ColumnTypeToString(BigQueryTableAccessor::ColumnType enum_type) {
- switch (enum_type) {
- case BigQueryTableAccessor::ColumnType::kRecord:
- return "RECORD";
- case BigQueryTableAccessor::ColumnType::kString:
- return "STRING";
- case BigQueryTableAccessor::ColumnType::kBytes:
- return "BYTES";
- case BigQueryTableAccessor::ColumnType::kInteger:
- return "INTEGER";
- case BigQueryTableAccessor::ColumnType::kFloat:
- return "FLOAT";
- case BigQueryTableAccessor::ColumnType::kBoolean:
- return "BOOLEAN";
- case BigQueryTableAccessor::ColumnType::kTimestamp:
- return "TIMESTAMP";
- case BigQueryTableAccessor::ColumnType::kDate:
- return "DATE";
- case BigQueryTableAccessor::ColumnType::kTime:
- return "TIME";
- case BigQueryTableAccessor::ColumnType::kDatetime:
- return "DATETIME";
- case BigQueryTableAccessor::ColumnType::kNone:
- return "NONE";
- }
-}
-
Status ParseColumnType(const string& type,
BigQueryTableAccessor::ColumnType* enum_type) {
if (type == "RECORD") {
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h
index 7d0eee59ae..b349063715 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
+#define TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
#include <map>
#include <memory>
@@ -205,4 +205,4 @@ class BigQueryTableAccessor {
};
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h
index b2b11f4f57..59f2333298 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
+#define TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
#include <string>
@@ -401,4 +401,4 @@ const string kTestEmptyRow = R"({
} // namespace
} // namepsace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
+#endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 817e96f5da..12bfd3c62b 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -134,6 +134,9 @@ if(WIN32)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /D_ITERATOR_DEBUG_LEVEL=0")
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /D_ITERATOR_DEBUG_LEVEL=0")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /D_ITERATOR_DEBUG_LEVEL=0")
+
+ # Try to avoid flaky failures due to failed generation of generate.stamp files.
+ set(CMAKE_SUPPRESS_REGENERATION ON)
endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake
index 013b3a862f..fd57734298 100644
--- a/tensorflow/contrib/cmake/external/snappy.cmake
+++ b/tensorflow/contrib/cmake/external/snappy.cmake
@@ -47,4 +47,4 @@ ExternalProject_Add(snappy
)
# actually enables snappy in the source code
-add_definitions(-DTF_USE_SNAPPY) \ No newline at end of file
+add_definitions(-DTF_USE_SNAPPY)
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index e37d059a84..9ce8b3cc9c 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -1,3 +1,5 @@
+# python_sanity_test.py will complain about invalid or missing entries
+# problematic entries can be commented for temporary whitelisting
tensorflow
tensorflow/core
tensorflow/core/example
@@ -31,9 +33,11 @@ tensorflow/python/grappler
tensorflow/python/keras
tensorflow/python/keras/activations
tensorflow/python/keras/applications
+tensorflow/python/keras/applications/densenet
tensorflow/python/keras/applications/inception_resnet_v2
tensorflow/python/keras/applications/inception_v3
tensorflow/python/keras/applications/mobilenet
+tensorflow/python/keras/applications/nasnet
tensorflow/python/keras/applications/resnet50
tensorflow/python/keras/applications/vgg16
tensorflow/python/keras/applications/vgg19
@@ -109,7 +113,6 @@ tensorflow/contrib/android/java/org/tensorflow/contrib
tensorflow/contrib/android/java/org/tensorflow/contrib/android
tensorflow/contrib/android/jni
tensorflow/contrib/batching
-tensorflow/contrib/batching/kernels
tensorflow/contrib/batching/python
tensorflow/contrib/batching/python/ops
tensorflow/contrib/bayesflow
@@ -308,6 +311,8 @@ tensorflow/contrib/metrics
tensorflow/contrib/metrics/python
tensorflow/contrib/metrics/python/metrics
tensorflow/contrib/metrics/python/ops
+tensorflow/contrib/mpi_collectives/python
+tensorflow/contrib/mpi_collectives/python/ops
tensorflow/contrib/model_pruning
tensorflow/contrib/model_pruning/examples
tensorflow/contrib/model_pruning/examples/cifar10
diff --git a/tensorflow/contrib/cmake/python_sanity_test.py b/tensorflow/contrib/cmake/python_sanity_test.py
new file mode 100644
index 0000000000..e0056823a8
--- /dev/null
+++ b/tensorflow/contrib/cmake/python_sanity_test.py
@@ -0,0 +1,128 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Complain about invalid or missing entries in python_*.txt files.
+
+Problematic entries can be commented for temporary whitelisting.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import unittest
+
+
+def abs_path(path):
+ root = os.path.dirname(__file__)
+
+ for _ in range(3):
+ root = os.path.join(root, os.pardir)
+
+ path = os.path.join(root, path)
+ path = os.path.abspath(path)
+ return path
+
+
+def read_entries(test):
+ with open(abs_path(test.entries_file), "r") as f:
+ lines = f.readlines()
+
+ lines = [line.strip() for line in lines]
+ lines = [line for line in lines if line]
+
+ test.entries = []
+ test.whitelist = []
+
+ for line in lines:
+ # line is comment
+ if line.startswith("#"):
+ line = line[1:].strip()
+ # whitelist entry
+ if line.startswith("tensorflow/"):
+ test.whitelist.append(line)
+ # line has comment -> strip comment
+ elif line.find("#") != -1:
+ line = line[:line.find("#")].strip()
+ test.entries.append(line)
+ else:
+ test.entries.append(line)
+
+
+def test_invalid_directories(test):
+ for entry in test.entries:
+ if not os.path.isdir(abs_path(entry)):
+ problem = "'" + test.entries_file + "' contains invalid '" + entry + "'"
+ solution = ("Please remove the invalid entry (or add the missing "
+ "directory).")
+ raise AssertionError(problem + "\n" + solution)
+
+
+def test_missing_directory(test, path):
+ if path in test.whitelist:
+ return
+
+ dir_exists = os.path.isdir(abs_path(path))
+ entry_exists = path in test.entries
+
+ if dir_exists and not entry_exists:
+ problem = "'" + test.entries_file + "' is missing '" + path + "'"
+ solution = "Please add the missing entry (comment to whitelist if needed)."
+ raise AssertionError(problem + "\n" + solution)
+
+
+class PythonModuleTest(unittest.TestCase):
+
+ def setUp(self):
+ self.entries_file = "tensorflow/contrib/cmake/python_modules.txt"
+ read_entries(self)
+
+ def testInvalidEntries(self):
+ test_invalid_directories(self)
+
+ def testMissingModules(self):
+ module_names = next(os.walk(abs_path("tensorflow/contrib")))[1]
+
+ for module_name in module_names:
+ path = "tensorflow/contrib/" + module_name
+
+ test_missing_directory(self, path + "/python")
+ test_missing_directory(self, path + "/python/ops")
+ test_missing_directory(self, path + "/python/kernels")
+ test_missing_directory(self, path + "/python/layers")
+
+
+class PythonProtoTest(unittest.TestCase):
+
+ def setUp(self):
+ self.entries_file = "tensorflow/contrib/cmake/python_protos.txt"
+ read_entries(self)
+
+ def testInvalidEntries(self):
+ test_invalid_directories(self)
+
+
+class PythonProtoCCTest(unittest.TestCase):
+
+ def setUp(self):
+ self.entries_file = "tensorflow/contrib/cmake/python_protos_cc.txt"
+ read_entries(self)
+
+ def testInvalidEntries(self):
+ test_invalid_directories(self)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 24d7fb82a2..129c208ecd 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -126,7 +126,9 @@ endfunction()
file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto"
+ "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto"
)
+
RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
${tensorflow_source_dir} ${tf_protos_cc_srcs}
)
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 6f56e9d086..138993db35 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -15,6 +15,7 @@
set(tf_op_lib_names
"audio_ops"
"array_ops"
+ "batch_ops"
"bitwise_ops"
"candidate_sampling_ops"
"checkpoint_ops"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 17bbdb1a86..8862390d2b 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -126,7 +126,8 @@ STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}")
STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}")
foreach(python_proto ${python_protos})
- if(NOT python_proto MATCHES "\#")
+ if(NOT python_proto MATCHES "^\#")
+ STRING(REGEX REPLACE " *\#.*" "" python_proto "${python_proto}")
if(NOT EXISTS "${tensorflow_source_dir}/${python_proto}")
message(SEND_ERROR "Python proto directory not found: ${python_proto}")
endif()
@@ -147,7 +148,8 @@ STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}")
STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}")
foreach(python_proto_cc ${python_protos_cc})
- if(NOT python_proto_cc MATCHES "\#")
+ if(NOT python_proto_cc MATCHES "^\#")
+ STRING(REGEX REPLACE " *\#.*" "" python_proto_cc "${python_proto_cc}")
if(NOT EXISTS "${tensorflow_source_dir}/${python_proto_cc}")
message(SEND_ERROR "Python proto CC directory not found: ${python_proto_cc}")
endif()
@@ -209,7 +211,8 @@ STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}")
STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}")
foreach(python_module ${python_modules})
- if(NOT python_module MATCHES "\#")
+ if(NOT python_module MATCHES "^\#")
+ STRING(REGEX REPLACE " *\#.*" "" python_module "${python_module}")
if(NOT EXISTS "${tensorflow_source_dir}/${python_module}")
message(SEND_ERROR "Python module not found: ${python_module}")
endif()
@@ -314,6 +317,7 @@ endfunction()
GENERATE_PYTHON_OP_LIB("audio_ops")
GENERATE_PYTHON_OP_LIB("array_ops")
+GENERATE_PYTHON_OP_LIB("batch_ops")
GENERATE_PYTHON_OP_LIB("bitwise_ops")
GENERATE_PYTHON_OP_LIB("math_ops")
GENERATE_PYTHON_OP_LIB("functional_ops")
diff --git a/tensorflow/contrib/coder/kernels/range_coder.h b/tensorflow/contrib/coder/kernels/range_coder.h
index c24fb707fc..f46413072e 100644
--- a/tensorflow/contrib/coder/kernels/range_coder.h
+++ b/tensorflow/contrib/coder/kernels/range_coder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_
+#ifndef TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_
+#define TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_
#include <limits>
#include <string>
@@ -106,4 +106,4 @@ class RangeDecoder {
};
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_
+#endif // TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_
diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_util.h b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h
index 95241a8682..b8aabcef62 100644
--- a/tensorflow/contrib/coder/kernels/range_coder_ops_util.h
+++ b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_
+#define TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_
#include <vector>
@@ -30,4 +30,4 @@ Status MergeAxes(const TensorShape& broadcast_shape,
std::vector<int64>* merged_storage_shape_pointer);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index bae66ffd42..b806799202 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -35,10 +35,10 @@ from tensorflow.python.ops.variables import Variable
from tensorflow.python.client.session import Session
from tensorflow.python.framework import ops
-__all__ = ["copy_op_to_graph", "copy_variable_to_graph", "get_copied_op"]
+__all__ = ['copy_op_to_graph', 'copy_variable_to_graph', 'get_copied_op']
-def copy_variable_to_graph(org_instance, to_graph, scope=""):
+def copy_variable_to_graph(org_instance, to_graph, scope=''):
"""Given a `Variable` instance from one `Graph`, initializes and returns
a copy of it from another `Graph`, under the specified scope
(default `""`).
@@ -56,12 +56,11 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
"""
if not isinstance(org_instance, Variable):
- raise TypeError(str(org_instance) + " is not a Variable")
+ raise TypeError(str(org_instance) + ' is not a Variable')
#The name of the new variable
- if scope != "":
- new_name = (scope + '/' +
- org_instance.name[:org_instance.name.index(':')])
+ if scope != '':
+ new_name = (scope + '/' + org_instance.name[:org_instance.name.index(':')])
else:
new_name = org_instance.name[:org_instance.name.index(':')]
@@ -73,15 +72,15 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
for name, collection in org_instance.graph._collections.items():
if org_instance in collection:
if (name == ops.GraphKeys.GLOBAL_VARIABLES or
- name == ops.GraphKeys.TRAINABLE_VARIABLES or
- scope == ''):
+ name == ops.GraphKeys.TRAINABLE_VARIABLES or scope == ''):
collections.append(name)
else:
collections.append(scope + '/' + name)
#See if its trainable.
- trainable = (org_instance in org_instance.graph.get_collection(
- ops.GraphKeys.TRAINABLE_VARIABLES))
+ trainable = (
+ org_instance in org_instance.graph.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES))
#Get the initial value
with org_instance.graph.as_default():
temp_session = Session()
@@ -89,17 +88,17 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
#Initialize the new variable
with to_graph.as_default():
- new_var = Variable(init_value,
- trainable,
- name=new_name,
- collections=collections,
- validate_shape=False)
+ new_var = Variable(
+ init_value,
+ trainable,
+ name=new_name,
+ collections=collections,
+ validate_shape=False)
return new_var
-def copy_op_to_graph(org_instance, to_graph, variables,
- scope=""):
+def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
"""Returns a copy of an operation from another Graph under a specified scope.
Given an `Operation` `org_instance` from one `Graph`,
@@ -139,14 +138,12 @@ def copy_op_to_graph(org_instance, to_graph, variables,
#If a variable by the new name already exists, return the
#correspondng tensor that will act as an input
if new_name in copied_variables:
- return to_graph.get_tensor_by_name(
- copied_variables[new_name].name)
+ return to_graph.get_tensor_by_name(copied_variables[new_name].name)
#If an instance of the same name exists, return appropriately
try:
- already_present = to_graph.as_graph_element(new_name,
- allow_tensor=True,
- allow_operation=True)
+ already_present = to_graph.as_graph_element(
+ new_name, allow_tensor=True, allow_operation=True)
return already_present
except:
pass
@@ -184,20 +181,21 @@ def copy_op_to_graph(org_instance, to_graph, variables,
#If it has an original_op parameter, copy it
if op._original_op is not None:
- new_original_op = copy_op_to_graph(op._original_op, to_graph,
- variables, scope)
+ new_original_op = copy_op_to_graph(op._original_op, to_graph, variables,
+ scope)
else:
new_original_op = None
#If it has control inputs, call this function recursively on each.
- new_control_inputs = [copy_op_to_graph(x, to_graph, variables,
- scope)
- for x in op.control_inputs]
+ new_control_inputs = [
+ copy_op_to_graph(x, to_graph, variables, scope)
+ for x in op.control_inputs
+ ]
#If it has inputs, call this function recursively on each.
- new_inputs = [copy_op_to_graph(x, to_graph, variables,
- scope)
- for x in op.inputs]
+ new_inputs = [
+ copy_op_to_graph(x, to_graph, variables, scope) for x in op.inputs
+ ]
#Make a new node_def based on that of the original.
#An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it
@@ -216,13 +214,8 @@ def copy_op_to_graph(org_instance, to_graph, variables,
op_def = deepcopy(op._op_def)
#Initialize a new Operation instance
- new_op = ops.Operation(new_node_def,
- to_graph,
- new_inputs,
- output_types,
- new_control_inputs,
- input_types,
- new_original_op,
+ new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types,
+ new_control_inputs, input_types, new_original_op,
op_def)
#Use Graph's hidden methods to add the op
to_graph._add_op(new_op) # pylint: disable=protected-access
@@ -233,10 +226,10 @@ def copy_op_to_graph(org_instance, to_graph, variables,
return new_op
else:
- raise TypeError("Could not copy instance: " + str(org_instance))
+ raise TypeError('Could not copy instance: ' + str(org_instance))
-def get_copied_op(org_instance, graph, scope=""):
+def get_copied_op(org_instance, graph, scope=''):
"""Given an `Operation` instance from some `Graph`, returns
its namesake from `graph`, under the specified scope
(default `""`).
@@ -259,5 +252,5 @@ def get_copied_op(org_instance, graph, scope=""):
else:
new_name = org_instance.name
- return graph.as_graph_element(new_name, allow_tensor=True,
- allow_operation=True)
+ return graph.as_graph_element(
+ new_name, allow_tensor=True, allow_operation=True)
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 1fbf18f30a..1cf0202fd8 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -475,7 +475,7 @@ py_test(
py_test(
name = "stats_dataset_ops_test",
- size = "small",
+ size = "medium",
srcs = ["stats_dataset_ops_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index 69252612a8..dd8247bfd4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -765,6 +765,15 @@ class MapDatasetSerializationTest(
self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+ def testCaptureConstantInMapFn(self):
+
+ def _build_ds():
+ constant_var = constant_op.constant(5)
+ return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda x: x + constant_var))
+
+ self.run_core_tests(_build_ds, None, 10)
+
def testCaptureDefunInMapFn(self):
num_outputs = 100
@@ -856,6 +865,15 @@ class ParallelMapDatasetSerializationTest(
self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+ def testCaptureConstantInMapFn(self):
+
+ def _build_ds():
+ constant_var = constant_op.constant(5)
+ return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda x: x + constant_var))
+
+ self.run_core_tests(_build_ds, None, 10)
+
def testCaptureDefunInMapFn(self):
num_outputs = 100
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 95848af699..7f510c4221 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -129,6 +129,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "autoregressive_test",
+ size = "small",
+ srcs = ["python/kernel_tests/autoregressive_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "binomial_test",
size = "small",
srcs = ["python/kernel_tests/binomial_test.py"],
@@ -919,6 +932,22 @@ cuda_py_test(
)
cuda_py_test(
+ name = "real_nvp_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/real_nvp_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "permute_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/permute_test.py"],
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 7b401e178f..60a187e541 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -23,6 +23,7 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.contrib.distributions.python.ops.autoregressive import *
from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.cauchy import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
@@ -84,6 +85,7 @@ from tensorflow.python.ops.distributions.uniform import *
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
+ 'auto_correlation',
'bijectors',
'Cauchy',
'ConditionalDistribution',
@@ -92,6 +94,7 @@ _allowed_symbols = [
'NOT_REPARAMETERIZED',
'ReparameterizationType',
'Distribution',
+ 'Autoregressive',
'Binomial',
'Bernoulli',
'BernoulliWithSigmoidProbs',
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py
new file mode 100644
index 0000000000..0928dc3f35
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py
@@ -0,0 +1,94 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import autoregressive as autoregressive_lib
+from tensorflow.contrib.distributions.python.ops import independent as independent_lib
+from tensorflow.contrib.distributions.python.ops import test_util
+from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
+from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
+from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.platform import test
+
+
+class AutogressiveTest(test_util.VectorDistributionTestHelpers, test.TestCase):
+ """Tests the Autoregressive distribution."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def _random_scale_tril(self, event_size):
+ n = np.int32(event_size * (event_size + 1) // 2)
+ p = 2. * self._rng.random_sample(n).astype(np.float32) - 1.
+ return distribution_util.fill_triangular(0.25 * p)
+
+ def _normal_fn(self, affine_bijector):
+ def _fn(samples):
+ scale = math_ops.exp(affine_bijector.forward(samples))
+ return independent_lib.Independent(
+ normal_lib.Normal(loc=0., scale=scale, validate_args=True),
+ reinterpreted_batch_ndims=1)
+ return _fn
+
+ def testSampleAndLogProbConsistency(self):
+ batch_shape = []
+ event_size = 2
+ with self.test_session() as sess:
+ batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
+ sample0 = array_ops.zeros(batch_event_shape)
+ affine = Affine(scale_tril=self._random_scale_tril(event_size))
+ ar = autoregressive_lib.Autoregressive(
+ self._normal_fn(affine), sample0, validate_args=True)
+ self.run_test_sample_consistent_log_prob(
+ sess.run, ar, radius=1., center=0., rtol=0.01)
+
+ def testCompareToBijector(self):
+ """Demonstrates equivalence between TD, Bijector approach and AR dist."""
+ sample_shape = np.int32([4, 5])
+ batch_shape = np.int32([])
+ event_size = np.int32(2)
+ with self.test_session() as sess:
+ batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
+ sample0 = array_ops.zeros(batch_event_shape)
+ affine = Affine(scale_tril=self._random_scale_tril(event_size))
+ ar = autoregressive_lib.Autoregressive(
+ self._normal_fn(affine), sample0, validate_args=True)
+ ar_flow = MaskedAutoregressiveFlow(
+ is_constant_jacobian=True,
+ shift_and_log_scale_fn=lambda x: [None, affine.forward(x)],
+ validate_args=True)
+ td = transformed_distribution_lib.TransformedDistribution(
+ distribution=normal_lib.Normal(loc=0., scale=1.),
+ bijector=ar_flow,
+ event_shape=[event_size],
+ batch_shape=batch_shape,
+ validate_args=True)
+ x_shape = np.concatenate(
+ [sample_shape, batch_shape, [event_size]], axis=0)
+ x = 2. * self._rng.random_sample(x_shape).astype(np.float32) - 1.
+ td_log_prob_, ar_log_prob_ = sess.run([td.log_prob(x), ar.log_prob(x)])
+ self.assertAllClose(td_log_prob_, ar_log_prob_, atol=0., rtol=1e-6)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
index 288d9d8dd6..dcfb0eb051 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
@@ -149,5 +149,17 @@ class MaskedAutoregressiveFlowShiftOnlyTest(MaskedAutoregressiveFlowTest):
}
+class MaskedAutoregressiveFlowUnrollLoopTest(MaskedAutoregressiveFlowTest):
+
+ @property
+ def _autoregressive_flow_kwargs(self):
+ return {
+ "shift_and_log_scale_fn": masked_autoregressive_default_template(
+ hidden_layers=[2], shift_only=False),
+ "is_constant_jacobian": False,
+ "unroll_loop": True,
+ }
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
new file mode 100644
index 0000000000..46fe779741
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
@@ -0,0 +1,144 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MaskedAutoregressiveFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib.distributions.python.ops import test_util
+from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert
+from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import real_nvp_default_template
+from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import RealNVP
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
+from tensorflow.python.platform import test
+
+
+class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase):
+
+ @property
+ def _real_nvp_kwargs(self):
+ return {
+ "shift_and_log_scale_fn": real_nvp_default_template(
+ hidden_layers=[3], shift_only=False),
+ "is_constant_jacobian": False,
+ }
+
+ def testBijector(self):
+ x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4 * 2)
+ with self.test_session() as sess:
+ nvp = RealNVP(
+ num_masked=4,
+ validate_args=True,
+ **self._real_nvp_kwargs)
+ x = constant_op.constant(x_)
+ forward_x = nvp.forward(x)
+ # Use identity to invalidate cache.
+ inverse_y = nvp.inverse(array_ops.identity(forward_x))
+ fldj = nvp.forward_log_det_jacobian(x)
+ # Use identity to invalidate cache.
+ ildj = nvp.inverse_log_det_jacobian(array_ops.identity(forward_x))
+ variables.global_variables_initializer().run()
+ [
+ forward_x_,
+ inverse_y_,
+ ildj_,
+ fldj_,
+ ] = sess.run([
+ forward_x,
+ inverse_y,
+ ildj,
+ fldj,
+ ])
+ self.assertEqual("real_nvp", nvp.name)
+ self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.)
+ self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.)
+ self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.)
+
+ def testMutuallyConsistent(self):
+ dims = 4
+ with self.test_session() as sess:
+ nvp = RealNVP(
+ num_masked=3,
+ validate_args=True,
+ **self._real_nvp_kwargs)
+ dist = transformed_distribution_lib.TransformedDistribution(
+ distribution=normal_lib.Normal(loc=0., scale=1.),
+ bijector=nvp,
+ event_shape=[dims],
+ validate_args=True)
+ self.run_test_sample_consistent_log_prob(
+ sess_run_fn=sess.run,
+ dist=dist,
+ num_samples=int(1e5),
+ radius=1.,
+ center=0.,
+ rtol=0.02)
+
+ def testInvertMutuallyConsistent(self):
+ dims = 4
+ with self.test_session() as sess:
+ nvp = Invert(RealNVP(
+ num_masked=3,
+ validate_args=True,
+ **self._real_nvp_kwargs))
+ dist = transformed_distribution_lib.TransformedDistribution(
+ distribution=normal_lib.Normal(loc=0., scale=1.),
+ bijector=nvp,
+ event_shape=[dims],
+ validate_args=True)
+ self.run_test_sample_consistent_log_prob(
+ sess_run_fn=sess.run,
+ dist=dist,
+ num_samples=int(1e5),
+ radius=1.,
+ center=0.,
+ rtol=0.02)
+
+
+class NICETest(RealNVPTest):
+
+ @property
+ def _real_nvp_kwargs(self):
+ return {
+ "shift_and_log_scale_fn": real_nvp_default_template(
+ hidden_layers=[2], shift_only=True),
+ "is_constant_jacobian": True,
+ }
+
+
+class RealNVPConstantShiftScaleTest(RealNVPTest):
+
+ @property
+ def _real_nvp_kwargs(self):
+
+ def constant_shift_log_scale_fn(x0, output_units):
+ del x0, output_units
+ shift = constant_op.constant([0.1])
+ log_scale = constant_op.constant([0.5])
+ return shift, log_scale
+
+ return {
+ "shift_and_log_scale_fn": constant_shift_log_scale_fn,
+ "is_constant_jacobian": True,
+ }
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
index 49451446b5..e216d88cb1 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
@@ -22,12 +22,15 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.platform import test
+@test_util.with_c_api
class _ReshapeBijectorTest(object):
"""Base class for testing the reshape transformation.
@@ -136,7 +139,8 @@ class _ReshapeBijectorTest(object):
sess.run(bijector.forward_event_shape_tensor(shape_in),
feed_dict=feed_dict)
- def testInvalidDimensionsOpError(self):
+ # pylint: disable=invalid-name
+ def _testInvalidDimensionsOpError(self, expected_error_message):
with self.test_session() as sess:
@@ -146,10 +150,10 @@ class _ReshapeBijectorTest(object):
event_shape_in=shape_in,
validate_args=True)
- with self.assertRaisesError(
- "elements must be either positive integers or `-1`."):
+ with self.assertRaisesError(expected_error_message):
sess.run(bijector.forward_event_shape_tensor(shape_in),
feed_dict=feed_dict)
+ # pylint: enable=invalid-name
def testValidButNonMatchingInputOpError(self):
x = np.random.randn(4, 3, 2)
@@ -184,7 +188,8 @@ class _ReshapeBijectorTest(object):
sess.run(bijector.forward(x),
feed_dict=feed_dict)
- def testInputOutputMismatchOpError(self):
+ # pylint: disable=invalid-name
+ def _testInputOutputMismatchOpError(self, expected_error_message):
x1 = np.random.randn(4, 2, 3)
x2 = np.random.randn(4, 1, 1, 5)
@@ -196,13 +201,11 @@ class _ReshapeBijectorTest(object):
event_shape_in=shape_in,
validate_args=True)
- # test that *all* methods check basic assertions
- with self.assertRaisesError(
- "Input to reshape is a tensor with"):
+ with self.assertRaisesError(expected_error_message):
sess.run(bijector.forward(x1), feed_dict=fd_mismatched)
- with self.assertRaisesError(
- "Input to reshape is a tensor with"):
+ with self.assertRaisesError(expected_error_message):
sess.run(bijector.inverse(x2), feed_dict=fd_mismatched)
+ # pylint: enable=invalid-name
def testOneShapePartiallySpecified(self):
expected_x = np.random.randn(4, 6)
@@ -262,6 +265,7 @@ class _ReshapeBijectorTest(object):
raise NotImplementedError("Subclass failed to implement `build_shapes`.")
+@test_util.with_c_api
class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
def build_shapes(self, shape_in, shape_out):
@@ -299,7 +303,22 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
validate_args=True)
assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0)
+ def testInvalidDimensionsOpError(self):
+ if ops._USE_C_API:
+ error_message = "Invalid value in tensor used for shape: -2"
+ else:
+ error_message = "elements must be either positive integers or `-1`."
+ self._testInvalidDimensionsOpError(error_message)
+
+ def testInputOutputMismatchOpError(self):
+ if ops._USE_C_API:
+ error_message = "Cannot reshape a tensor with"
+ else:
+ error_message = "Input to reshape is a tensor with"
+ self._testInputOutputMismatchOpError(error_message)
+
+@test_util.with_c_api
class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest):
def build_shapes(self, shape_in, shape_out):
@@ -313,7 +332,15 @@ class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest):
def assertRaisesError(self, msg):
return self.assertRaisesOpError(msg)
+ def testInvalidDimensionsOpError(self):
+ self._testInvalidDimensionsOpError(
+ "elements must be either positive integers or `-1`.")
+
+ def testInputOutputMismatchOpError(self):
+ self._testInputOutputMismatchOpError("Input to reshape is a tensor with")
+
+@test_util.with_c_api
class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest):
def build_shapes(self, shape_in, shape_out):
@@ -325,6 +352,13 @@ class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest):
def assertRaisesError(self, msg):
return self.assertRaisesOpError(msg)
+ def testInvalidDimensionsOpError(self):
+ self._testInvalidDimensionsOpError(
+ "elements must be either positive integers or `-1`.")
+
+ def testInputOutputMismatchOpError(self):
+ self._testInputOutputMismatchOpError("Input to reshape is a tensor with")
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
index d292b04665..04f047aa0c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py
@@ -27,6 +27,8 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
from tensorflow.python.platform import test
+rng = np.random.RandomState(0)
+
class VectorDiffeomixtureTest(
test_util.VectorDistributionTestHelpers, test.TestCase):
@@ -37,7 +39,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
- mix_scale=[1.],
+ temperature=[1.],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -66,7 +68,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
- mix_scale=[1.],
+ temperature=[1.],
distribution=normal_lib.Normal(1., 1.5),
loc=[
None,
@@ -95,7 +97,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [1.]],
- mix_scale=[1.],
+ temperature=[1.],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -122,12 +124,39 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_log_prob(
sess.run, vdm, radius=4., center=2., rtol=0.01)
+ def testSampleProbConsistentBroadcastMixTwoBatchDims(self):
+ dims = 4
+ loc_1 = rng.randn(2, 3, dims).astype(np.float32)
+
+ with self.test_session() as sess:
+ vdm = vdm_lib.VectorDiffeomixture(
+ mix_loc=(rng.rand(2, 3, 1) - 0.5).astype(np.float32),
+ temperature=[1.],
+ distribution=normal_lib.Normal(0., 1.),
+ loc=[
+ None,
+ loc_1,
+ ],
+ scale=[
+ linop_identity_lib.LinearOperatorScaledIdentity(
+ num_rows=dims,
+ multiplier=[np.float32(1.1)],
+ is_positive_definite=True),
+ ] * 2,
+ validate_args=True)
+ # Ball centered at component0's mean.
+ self.run_test_sample_consistent_log_prob(
+ sess.run, vdm, radius=2., center=0., rtol=0.01)
+ # Larger ball centered at component1's mean.
+ self.run_test_sample_consistent_log_prob(
+ sess.run, vdm, radius=3., center=loc_1, rtol=0.02)
+
def testMeanCovarianceNoBatch(self):
with self.test_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
- mix_scale=[10.],
+ temperature=[1 / 10.],
distribution=normal_lib.Normal(0., 1.),
loc=[
np.float32([-2.]),
@@ -147,12 +176,94 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_mean_covariance(
sess.run, vdm, rtol=0.02, cov_rtol=0.08)
+ def testTemperatureControlsHowMuchThisLooksLikeDiscreteMixture(self):
+ # As temperature decreases, this should approach a mixture of normals, with
+ # components at -2, 2.
+ with self.test_session() as sess:
+ dims = 1
+ vdm = vdm_lib.VectorDiffeomixture(
+ mix_loc=[0.],
+ temperature=[[2.], [1.], [0.2]],
+ distribution=normal_lib.Normal(0., 1.),
+ loc=[
+ np.float32([-2.]),
+ np.float32([2.]),
+ ],
+ scale=[
+ linop_identity_lib.LinearOperatorScaledIdentity(
+ num_rows=dims,
+ multiplier=np.float32(0.5),
+ is_positive_definite=True),
+ ] * 2, # Use the same scale for each component.
+ quadrature_size=8,
+ validate_args=True)
+
+ samps = vdm.sample(10000)
+ self.assertAllEqual((10000, 3, 1), samps.shape)
+ samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape.
+
+ # One characteristic of a discrete mixture (as opposed to a "smear") is
+ # that more weight is put near the component centers at -2, 2, and thus
+ # less weight is put near the origin.
+ prob_of_being_near_origin = (np.abs(samps_) < 1).mean(axis=0)
+ self.assertGreater(
+ prob_of_being_near_origin[0], prob_of_being_near_origin[1])
+ self.assertGreater(
+ prob_of_being_near_origin[1], prob_of_being_near_origin[2])
+
+ # Run this test as well, just because we can.
+ self.run_test_sample_consistent_mean_covariance(
+ sess.run, vdm, rtol=0.02, cov_rtol=0.08)
+
+ def testConcentrationLocControlsHowMuchWeightIsOnEachComponent(self):
+ with self.test_session() as sess:
+ dims = 1
+ vdm = vdm_lib.VectorDiffeomixture(
+ mix_loc=[[-1.], [0.], [1.]],
+ temperature=[0.5],
+ distribution=normal_lib.Normal(0., 1.),
+ loc=[
+ np.float32([-2.]),
+ np.float32([2.]),
+ ],
+ scale=[
+ linop_identity_lib.LinearOperatorScaledIdentity(
+ num_rows=dims,
+ multiplier=np.float32(0.5),
+ is_positive_definite=True),
+ ] * 2, # Use the same scale for each component.
+ quadrature_size=8,
+ validate_args=True)
+
+ samps = vdm.sample(10000)
+ self.assertAllEqual((10000, 3, 1), samps.shape)
+ samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape.
+
+ # One characteristic of putting more weight on a component is that the
+ # mean is closer to that component's mean.
+ # Get the mean for each batch member, the names signify the value of
+ # concentration for that batch member.
+ mean_neg1, mean_0, mean_1 = samps_.mean(axis=0)
+
+ # Since concentration is the concentration for component 0,
+ # concentration = -1 ==> more weight on component 1, which has mean = 2
+ # concentration = 0 ==> equal weight
+ # concentration = 1 ==> more weight on component 0, which has mean = -2
+ self.assertLess(-2, mean_1)
+ self.assertLess(mean_1, mean_0)
+ self.assertLess(mean_0, mean_neg1)
+ self.assertLess(mean_neg1, 2)
+
+ # Run this test as well, just because we can.
+ self.run_test_sample_consistent_mean_covariance(
+ sess.run, vdm, rtol=0.02, cov_rtol=0.08)
+
def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
with self.test_session() as sess:
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
- mix_scale=[10.],
+ temperature=[0.1],
distribution=normal_lib.Normal(-1., 1.5),
loc=[
np.float32([-2.]),
@@ -177,7 +288,7 @@ class VectorDiffeomixtureTest(
dims = 3
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[[0.], [4.]],
- mix_scale=[10.],
+ temperature=[0.1],
distribution=normal_lib.Normal(0., 1.),
loc=[
np.float32([[-2.]]),
@@ -205,7 +316,7 @@ class VectorDiffeomixtureTest(
dims = 4
vdm = vdm_lib.VectorDiffeomixture(
mix_loc=[0.],
- mix_scale=[1.],
+ temperature=[0.1],
distribution=normal_lib.Normal(0., 1.),
loc=[
None,
@@ -229,29 +340,6 @@ class VectorDiffeomixtureTest(
self.run_test_sample_consistent_log_prob(
sess.run, vdm, radius=4., center=2., rtol=0.005)
- # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent,
- # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't
- # tested that the quadrature approach well-approximates the integral.
- #
- # To that end, consider adding these tests:
- #
- # Test1: In the limit of high mix_scale, this approximates a discrete mixture,
- # and there are many discrete mixtures where we can explicitly compute
- # mean/var, etc... So test1 would choose one of those discrete mixtures and
- # show our mean/var/etc... is close to that.
- #
- # Test2: In the limit of low mix_scale, the a diffeomixture of Normal(-5, 1),
- # Normal(5, 1) should (I believe...must check) should look almost like
- # Uniform(-5, 5), and thus (i) .prob(x) should be about 1/10 for x in (-5, 5),
- # and (ii) the first few moments should approximately match that of
- # Uniform(-5, 5)
- #
- # Test3: If mix_loc is symmetric, then for any mix_scale, our
- # quadrature-based diffeomixture of Normal(-1, 1), Normal(1, 1) should have
- # mean zero, exactly.
-
- # TODO(jvdillon): Add more tests which verify broadcasting.
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
new file mode 100644
index 0000000000..852298bf33
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -0,0 +1,208 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Autoregressive distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import util as distribution_util
+
+
+class Autoregressive(distribution_lib.Distribution):
+ """Autoregressive distributions.
+
+ The Autoregressive distribution enables learning (often) richer multivariate
+ distributions by repeatedly applying a [diffeomorphic](
+ https://en.wikipedia.org/wiki/Diffeomorphism) transformation (such as
+ implemented by `Bijector`s). Regarding terminology,
+
+ "Autoregressive models decompose the joint density as a product of
+ conditionals, and model each conditional in turn. Normalizing flows
+ transform a base density (e.g. a standard Gaussian) into the target density
+ by an invertible transformation with tractable Jacobian." [1]
+
+ In other words, the "autoregressive property" is equivalent to the
+ decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided
+ `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves
+ this property by zeroing out weights in its `masked_dense` layers.
+
+ Practically speaking the autoregressive property means that there exists a
+ permutation of the event coordinates such that each coordinate is a
+ diffeomorphic function of only preceding coordinates. [2]
+
+ #### Mathematical Details
+
+ The probability function is,
+
+ ```none
+ prob(x; fn, n) = fn(x).prob(x)
+ ```
+
+ And a sample is generated by,
+
+ ```none
+ x = fn(...fn(fn(x0).sample()).sample()).sample()
+ ```
+
+ where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn`
+ constructs a `tf.distributions.Distribution`-like instance, and `x0` is a
+ fixed initializing `Tensor`.
+
+ #### Examples
+
+ ```python
+ tfd = tf.contrib.distributions
+
+ def normal_fn(self, event_size):
+ n = event_size * (event_size + 1) / 2
+ p = tf.Variable(tfd.Normal(loc=0., scale=1.).sample(n))
+ affine = tfd.bijectors.Affine(
+ scale_tril=tfd.fill_triangular(0.25 * p))
+ def _fn(samples):
+ scale = math_ops.exp(affine.forward(samples)).eval()
+ return independent_lib.Independent(
+ normal_lib.Normal(loc=0., scale=scale, validate_args=True),
+ reinterpreted_batch_ndims=1)
+ return _fn
+
+ batch_and_event_shape = [3, 2, 4]
+ sample0 = array_ops.zeros(batch_and_event_shape)
+ ar = autoregressive_lib.Autoregressive(
+ self._normal_fn(batch_and_event_shape[-1]), sample0)
+ x = ar.sample([6, 5])
+ # ==> x.shape = [6, 5, 3, 2, 4]
+ prob_x = ar.prob(x)
+ # ==> x.shape = [6, 5, 3, 2]
+
+ ```
+
+ [1]: "Masked Autoregressive Flow for Density Estimation."
+ George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017.
+ https://arxiv.org/abs/1705.07057
+
+ [2]: "Conditional Image Generation with PixelCNN Decoders."
+ Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex
+ Graves, Koray Kavukcuoglu. Arxiv, 2016.
+ https://arxiv.org/abs/1606.05328
+ """
+
+ def __init__(self,
+ distribution_fn,
+ sample0=None,
+ num_steps=None,
+ validate_args=False,
+ allow_nan_stats=True,
+ name="Autoregressive"):
+ """Construct an `Autoregressive` distribution.
+
+ Args:
+ distribution_fn: Python `callable` which constructs a
+ `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
+ `sample0`). The function must respect the "autoregressive property",
+ i.e., there exists a permutation of event such that each coordinate is a
+ diffeomorphic function of on preceding coordinates.
+ sample0: Initial input to `distribution_fn`; used to
+ build the distribution in `__init__` which in turn specifies this
+ distribution's properties, e.g., `event_shape`, `batch_shape`, `dtype`.
+ If unspecified, then `distribution_fn` should be default constructable.
+ num_steps: Number of times `distribution_fn` is composed from samples,
+ e.g., `num_steps=2` implies
+ `distribution_fn(distribution_fn(sample0).sample(n)).sample()`.
+ validate_args: Python `bool`. Whether to validate input with asserts.
+ If `validate_args` is `False`, and the inputs are invalid,
+ correct behavior is not guaranteed.
+ allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or
+ more of the statistic's batch members are undefined.
+ name: Python `str` name prefixed to Ops created by this class.
+ Default value: "Autoregressive".
+
+ Raises:
+ ValueError: if `num_steps` and
+ `distribution_fn(sample0).event_shape.num_elements()` are both `None`.
+ ValueError: if `num_steps < 1`.
+ """
+ parameters = locals()
+ with ops.name_scope(name):
+ self._distribution_fn = distribution_fn
+ self._sample0 = sample0
+ self._distribution0 = (distribution_fn() if sample0 is None
+ else distribution_fn(sample0))
+ if num_steps is None:
+ num_steps = self._distribution0.event_shape.num_elements()
+ if num_steps is None:
+ raise ValueError("distribution_fn must generate a distribution "
+ "with fully known `event_shape`.")
+ if num_steps < 1:
+ raise ValueError("num_steps ({}) must be at least 1.".format(num_steps))
+ self._num_steps = num_steps
+ super(Autoregressive, self).__init__(
+ dtype=self._distribution0.dtype,
+ reparameterization_type=self._distribution0.reparameterization_type,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ graph_parents=self._distribution0._graph_parents, # pylint: disable=protected-access
+ name=name)
+
+ @property
+ def distribution_fn(self):
+ return self._distribution_fn
+
+ @property
+ def sample0(self):
+ return self._sample0
+
+ @property
+ def num_steps(self):
+ return self._num_steps
+
+ @property
+ def distribution0(self):
+ return self._distribution0
+
+ def _batch_shape(self):
+ return self.distribution0.batch_shape
+
+ def _batch_shape_tensor(self):
+ return self.distribution0.batch_shape_tensor()
+
+ def _event_shape(self):
+ return self.distribution0.event_shape
+
+ def _event_shape_tensor(self):
+ return self.distribution0.event_shape_tensor()
+
+ def _sample_n(self, n, seed=None):
+ if seed is None:
+ seed = distribution_util.gen_new_seed(
+ seed=np.random.randint(2**32 - 1),
+ salt="autoregressive")
+ samples = self.distribution0.sample(n, seed=seed)
+ for _ in range(self._num_steps):
+ samples = self.distribution_fn(samples).sample(seed=seed)
+ return samples
+
+ def _log_prob(self, value):
+ return self.distribution_fn(value).log_prob(value)
+
+ def _prob(self, value):
+ return self.distribution_fn(value).prob(value)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index bc0ec7f195..93923c3f08 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -29,6 +29,7 @@
@@MaskedAutoregressiveFlow
@@Permute
@@PowerTransform
+@@RealNVP
@@Reshape
@@Sigmoid
@@SigmoidCentered
@@ -39,6 +40,7 @@
@@masked_autoregressive_default_template
@@masked_dense
+@@real_nvp_default_template
"""
from __future__ import absolute_import
@@ -60,6 +62,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
+from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import *
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import *
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import *
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import *
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
index 06c7c61ec3..dc8ae1eed1 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
@@ -182,6 +182,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector):
shift_and_log_scale_fn,
is_constant_jacobian=False,
validate_args=False,
+ unroll_loop=False,
name=None):
"""Creates the MaskedAutoregressiveFlow bijector.
@@ -201,16 +202,40 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector):
inefficient.)
validate_args: Python `bool` indicating whether arguments should be
checked for correctness.
+ unroll_loop: Python `bool` indicating whether the `tf.while_loop` in
+ `_forward` should be replaced with a static for loop. Requires that
+ the final dimension of `x` be known at graph construction time. Defaults
+ to `False`.
name: Python `str`, name given to ops managed by this object.
"""
name = name or "masked_autoregressive_flow"
self._shift_and_log_scale_fn = shift_and_log_scale_fn
+ self._unroll_loop = unroll_loop
super(MaskedAutoregressiveFlow, self).__init__(
is_constant_jacobian=is_constant_jacobian,
validate_args=validate_args,
name=name)
def _forward(self, x):
+ if self._unroll_loop:
+ event_size = x.shape.with_rank_at_least(1)[-1].value
+ if event_size is None:
+ raise ValueError(
+ "The final dimension of `x` must be known at graph construction "
+ "time if `unroll_loop=True`. `x.shape: %r`" % x.shape)
+ y = array_ops.zeros_like(x, name="y0")
+
+ for _ in range(event_size):
+ shift, log_scale = self._shift_and_log_scale_fn(y)
+ # next_y = scale * x + shift
+ next_y = x
+ if log_scale is not None:
+ next_y *= math_ops.exp(log_scale)
+ if shift is not None:
+ next_y += shift
+ y = next_y
+ return y
+
event_size = array_ops.shape(x)[-1]
y0 = array_ops.zeros_like(x, name="y0")
# call the template once to ensure creation
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
new file mode 100644
index 0000000000..2840f52e74
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
@@ -0,0 +1,282 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Real NVP bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import core as layers
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import template as template_ops
+from tensorflow.python.ops.distributions import bijector as bijector_lib
+
+
+__all__ = [
+ "RealNVP",
+ "real_nvp_default_template"
+]
+
+
+class RealNVP(bijector_lib.Bijector):
+ """RealNVP "affine coupling layer" for vector-valued events.
+
+ Real NVP models a normalizing flow on a `D`-dimensional distribution via a
+ single `D-d`-dimensional conditional distribution [1]:
+
+ `y[d:D] = y[d:D] * math_ops.exp(log_scale_fn(y[d:D])) + shift_fn(y[d:D])`
+ `y[0:d] = x[0:d]`
+
+ The last `D-d` units are scaled and shifted based on the first `d` units only,
+ while the first `d` units are 'masked' and left unchanged. Real NVP's
+ `shift_and_log_scale_fn` computes vector-valued quantities. For
+ scale-and-shift transforms that do not depend on any masked units, i.e.
+ `d=0`, use the `tfb.Affine` bijector with learned parameters instead.
+
+ Masking is currently only supported for base distributions with
+ `event_ndims=1`. For more sophisticated masking schemes like checkerboard or
+ channel-wise masking [2], use the `tfb.Permute` bijector to re-order desired
+ masked units into the first `d` units. For base distributions with
+ `event_ndims > 1`, use the `tfb.Reshape` bijector to flatten the event shape.
+
+ Recall that the MAF bijector [2] implements a normalizing flow via an
+ autoregressive transformation. MAF and IAF have opposite computational
+ tradeoffs - MAF can train all units in parallel but must sample units
+ sequentially, while IAF must train units sequentially but can sample in
+ parallel. In contrast, Real NVP can compute both forward and inverse
+ computations in parallel. However, the lack of an autoregressive
+ transformations makes it less expressive on a per-bijector basis.
+
+ A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or
+ "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each are broadcastable
+ with the arguments to `forward` and `inverse`, i.e., such that the
+ calculations in `forward`, `inverse` [below] are possible. For convenience,
+ `real_nvp_default_nvp` is offered as a possible `shift_and_log_scale_fn`
+ function.
+
+ NICE [3] is a special case of the Real NVP bijector which discards the scale
+ transformation, resulting in a constant-time inverse-log-determinant-Jacobian.
+ To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should
+ return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in
+ the `RealNVP` constructor. Calling `real_nvp_default_template` with
+ `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`.
+
+ Caching: the scalar input depth `D` of the base distribution is not known at
+ construction time. The first call to any of `forward(x)`, `inverse(x)`,
+ `inverse_log_det_jacobian(x)`, or `forward_log_det_jacobian(x)` memoizes
+ `D`, which is re-used in subsequent calls. This shape must be known prior to
+ graph execution (which is the case if using tf.layers).
+
+ #### Example Use
+
+ ```python
+ tfd = tf.contrib.distributions
+ tfb = tfd.bijectors
+
+ # A common choice for a normalizing flow is to use a Gaussian for the base
+ # distribution. (However, any continuous distribution would work.) E.g.,
+ nvp = tfd.TransformedDistribution(
+ distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.])),
+ bijector=tfb.RealNVP(
+ num_masked=2,
+ shift_and_log_scale_fn=tfb.real_nvp_default_template(
+ hidden_layers=[512, 512])))
+
+ x = nvp.sample()
+ nvp.log_prob(x)
+ nvp.log_prob(0.)
+ ```
+
+ For more examples, see [4].
+
+ [1]: "Density Estimation using Real NVP."
+ Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017.
+ https://arxiv.org/abs/1605.08803
+
+ [2]: "Masked Autoregressive Flow for Density Estimation."
+ George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017.
+ https://arxiv.org/abs/1705.07057
+
+ [3]: "NICE: Non-linear Independent Components Estimation."
+ Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015.
+ https://arxiv.org/abs/1410.8516
+
+ [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows."
+ Eric Jang. Blog post. January 2018.
+ http://blog.evjang.com/2018/01/nf2.html
+ """
+
+ def __init__(self,
+ num_masked,
+ shift_and_log_scale_fn,
+ is_constant_jacobian=False,
+ validate_args=False,
+ name=None):
+ """Creates the Real NVP or NICE bijector.
+
+ Args:
+ num_masked: Python `int` indicating that the first `d` units of the event
+ should be masked. Must be in the closed interval `[1, D-1]`, where `D`
+ is the event size of the base distribution.
+ shift_and_log_scale_fn: Python `callable` which computes `shift` and
+ `log_scale` from both the forward domain (`x`) and the inverse domain
+ (`y`). Calculation must respect the "autoregressive property" (see class
+ docstring). Suggested default
+ `masked_autoregressive_default_template(hidden_layers=...)`.
+ Typically the function contains `tf.Variables` and is wrapped using
+ `tf.make_template`. Returning `None` for either (both) `shift`,
+ `log_scale` is equivalent to (but more efficient than) returning zero.
+ is_constant_jacobian: Python `bool`. Default: `False`. When `True` the
+ implementation assumes `log_scale` does not depend on the forward domain
+ (`x`) or inverse domain (`y`) values. (No validation is made;
+ `is_constant_jacobian=False` is always safe but possibly computationally
+ inefficient.)
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str`, name given to ops managed by this object.
+
+ Raises:
+ ValueError: If num_masked < 1.
+ """
+ name = name or "real_nvp"
+ if num_masked <= 0:
+ raise ValueError("num_masked must be a positive integer.")
+ self._num_masked = num_masked
+ # At construction time, we don't know input_depth.
+ self._input_depth = None
+ self._shift_and_log_scale_fn = shift_and_log_scale_fn
+ super(RealNVP, self).__init__(
+ event_ndims=1,
+ is_constant_jacobian=is_constant_jacobian,
+ validate_args=validate_args,
+ name=name)
+
+ def _cache_input_depth(self, x):
+ if self._input_depth is None:
+ self._input_depth = x.shape.with_rank_at_least(1)[-1].value
+ if self._input_depth is None:
+ raise NotImplementedError(
+ "Rightmost dimension must be known prior to graph execution.")
+ if self._num_masked >= self._input_depth:
+ raise ValueError(
+ "Number of masked units must be smaller than the event size.")
+
+ def _forward(self, x):
+ self._cache_input_depth(x)
+ # Performs scale and shift.
+ x0, x1 = x[:, :self._num_masked], x[:, self._num_masked:]
+ shift, log_scale = self._shift_and_log_scale_fn(
+ x0, self._input_depth - self._num_masked)
+ y1 = x1
+ if log_scale is not None:
+ y1 *= math_ops.exp(log_scale)
+ if shift is not None:
+ y1 += shift
+ y = array_ops.concat([x0, y1], axis=-1)
+ return y
+
+ def _inverse(self, y):
+ self._cache_input_depth(y)
+ # Performs un-shift and un-scale.
+ y0, y1 = y[:, :self._num_masked], y[:, self._num_masked:]
+ shift, log_scale = self._shift_and_log_scale_fn(
+ y0, self._input_depth - self._num_masked)
+ x1 = y1
+ if shift is not None:
+ x1 -= shift
+ if log_scale is not None:
+ x1 *= math_ops.exp(-log_scale)
+ x = array_ops.concat([y0, x1], axis=-1)
+ return x
+
+ def _inverse_log_det_jacobian(self, y):
+ self._cache_input_depth(y)
+ y0 = y[:, :self._num_masked]
+ _, log_scale = self._shift_and_log_scale_fn(
+ y0, self._input_depth - self._num_masked)
+ if log_scale is None:
+ return constant_op.constant(0., dtype=y.dtype, name="ildj")
+ return -math_ops.reduce_sum(log_scale, axis=-1)
+
+ def _forward_log_det_jacobian(self, x):
+ self._cache_input_depth(x)
+ x0 = x[:, :self._num_masked]
+ _, log_scale = self._shift_and_log_scale_fn(
+ x0, self._input_depth - self._num_masked)
+ if log_scale is None:
+ return constant_op.constant(0., dtype=x.dtype, name="ildj")
+ return math_ops.reduce_sum(log_scale, axis=-1)
+
+
+def real_nvp_default_template(
+ hidden_layers,
+ shift_only=False,
+ activation=nn_ops.relu,
+ name=None,
+ *args,
+ **kwargs):
+ """Build a scale-and-shift function using a multi-layer neural network.
+
+ This will be wrapped in a make_template to ensure the variables are only
+ created once. It takes the `d`-dimensional input x[0:d] and returns the `D-d`
+ dimensional outputs `loc` ("mu") and `log_scale` ("alpha").
+
+ Arguments:
+ hidden_layers: Python `list`-like of non-negative integer, scalars
+ indicating the number of units in each hidden layer. Default: `[512, 512].
+ shift_only: Python `bool` indicating if only the `shift` term shall be
+ computed (i.e. NICE bijector). Default: `False`.
+ activation: Activation function (callable). Explicitly setting to `None`
+ implies a linear activation.
+ name: A name for ops managed by this function. Default:
+ "real_nvp_default_template".
+ *args: `tf.layers.dense` arguments.
+ **kwargs: `tf.layers.dense` keyword arguments.
+
+ Returns:
+ shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]).
+ log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]).
+
+ Raises:
+ NotImplementedError: if rightmost dimension of `inputs` is unknown prior to
+ graph execution.
+ """
+
+ with ops.name_scope(name, "real_nvp_default_template"):
+ def _fn(x, output_units):
+ """Fully connected MLP parameterized via `real_nvp_template`."""
+ for units in hidden_layers:
+ x = layers.dense(
+ inputs=x,
+ units=units,
+ activation=activation,
+ *args,
+ **kwargs)
+ x = layers.dense(
+ inputs=x,
+ units=(1 if shift_only else 2) * output_units,
+ activation=None,
+ *args,
+ **kwargs)
+ if shift_only:
+ return x, None
+ shift, log_scale = array_ops.split(x, 2, axis=-1)
+ return shift, log_scale
+ return template_ops.make_template(
+ "real_nvp_default_template", _fn)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 0ca236c376..49afbea7f0 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -248,7 +248,7 @@ class MixtureSameFamily(distribution.Distribution):
x = self._pad_sample_dims(x)
log_prob_x = self.components_distribution.log_prob(x) # [S, B, k]
log_mix_prob = nn_ops.log_softmax(
- self.mixture_distribution.logits, dim=-1) # [B, k]
+ self.mixture_distribution.logits, axis=-1) # [B, k]
return math_ops.reduce_logsumexp(
log_prob_x + log_mix_prob, axis=-1) # [S, B]
@@ -264,7 +264,7 @@ class MixtureSameFamily(distribution.Distribution):
x = self._pad_sample_dims(x)
log_cdf_x = self.components_distribution.log_cdf(x) # [S, B, k]
log_mix_prob = nn_ops.log_softmax(
- self.mixture_distribution.logits, dim=-1) # [B, k]
+ self.mixture_distribution.logits, axis=-1) # [B, k]
return math_ops.reduce_logsumexp(
log_cdf_x + log_mix_prob, axis=-1) # [S, B]
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 7ce8a83fd9..0c747f8e68 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -50,20 +50,25 @@ __all__ = [
def quadrature_scheme_softmaxnormal_gauss_hermite(
- loc, scale, quadrature_size,
+ normal_loc, normal_scale, quadrature_size,
validate_args=False, name=None):
"""Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex.
+ A `SoftmaxNormal` random variable `Y` may be generated via
+
+ ```
+ Y = SoftmaxCentered(X),
+ X = Normal(normal_loc, normal_scale)
+ ```
+
Note: for a given `quadrature_size`, this method is generally less accurate
than `quadrature_scheme_softmaxnormal_quantiles`.
Args:
- loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
- Represents the `location` parameter of the SoftmaxNormal used for
- selecting one of the `K` affine transformations.
- scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
- Represents the `scale` parameter of the SoftmaxNormal used for
- selecting one of the `K` affine transformations.
+ normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
+ The location parameter of the Normal used to construct the SoftmaxNormal.
+ normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
+ The scale parameter of the Normal used to construct the SoftmaxNormal.
quadrature_size: Python `int` scalar representing the number of quadrature
points.
validate_args: Python `bool`, default `False`. When `True` distribution
@@ -80,24 +85,25 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
associated with each grid point.
"""
with ops.name_scope(name, "quadrature_scheme_softmaxnormal_gauss_hermite",
- [loc, scale]):
- loc = ops.convert_to_tensor(loc, name="loc")
- dt = loc.dtype.base_dtype
- scale = ops.convert_to_tensor(scale, dtype=dt, name="scale")
+ [normal_loc, normal_scale]):
+ normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc")
+ dt = normal_loc.dtype.base_dtype
+ normal_scale = ops.convert_to_tensor(
+ normal_scale, dtype=dt, name="normal_scale")
- loc = maybe_check_quadrature_param(loc, "loc", validate_args)
- scale = maybe_check_quadrature_param(scale, "scale", validate_args)
+ normal_scale = maybe_check_quadrature_param(
+ normal_scale, "normal_scale", validate_args)
grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size)
- grid = grid.astype(loc.dtype.as_numpy_dtype)
- probs = probs.astype(loc.dtype.as_numpy_dtype)
+ grid = grid.astype(dt.dtype.as_numpy_dtype)
+ probs = probs.astype(dt.dtype.as_numpy_dtype)
probs /= np.linalg.norm(probs, ord=1, keepdims=True)
- probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype)
+ probs = ops.convert_to_tensor(probs, name="probs", dtype=dt)
grid = softmax(
-distribution_util.pad(
- (loc[..., array_ops.newaxis] +
- np.sqrt(2.) * scale[..., array_ops.newaxis] * grid),
+ (normal_loc[..., array_ops.newaxis] +
+ np.sqrt(2.) * normal_scale[..., array_ops.newaxis] * grid),
axis=-2,
front=True),
axis=-2) # shape: [B, components, deg]
@@ -106,18 +112,23 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
def quadrature_scheme_softmaxnormal_quantiles(
- loc, scale, quadrature_size,
+ normal_loc, normal_scale, quadrature_size,
validate_args=False, name=None):
"""Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.
+ A `SoftmaxNormal` random variable `Y` may be generated via
+
+ ```
+ Y = SoftmaxCentered(X),
+ X = Normal(normal_loc, normal_scale)
+ ```
+
Args:
- loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
- Represents the `location` parameter of the SoftmaxNormal used for
- selecting one of the `K` affine transformations.
- scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
- Represents the `scale` parameter of the SoftmaxNormal used for
- selecting one of the `K` affine transformations.
- quadrature_size: Python scalar `int` representing the number of quadrature
+ normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
+ The location parameter of the Normal used to construct the SoftmaxNormal.
+ normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
+ The scale parameter of the Normal used to construct the SoftmaxNormal.
+ quadrature_size: Python `int` scalar representing the number of quadrature
points.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
@@ -132,15 +143,17 @@ def quadrature_scheme_softmaxnormal_quantiles(
probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
associated with each grid point.
"""
- with ops.name_scope(name, "softmax_normal_grid_and_probs", [loc, scale]):
- loc = ops.convert_to_tensor(loc, name="loc")
- dt = loc.dtype.base_dtype
- scale = ops.convert_to_tensor(scale, dtype=dt, name="scale")
+ with ops.name_scope(name, "softmax_normal_grid_and_probs",
+ [normal_loc, normal_scale]):
+ normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc")
+ dt = normal_loc.dtype.base_dtype
+ normal_scale = ops.convert_to_tensor(
+ normal_scale, dtype=dt, name="normal_scale")
- loc = maybe_check_quadrature_param(loc, "loc", validate_args)
- scale = maybe_check_quadrature_param(scale, "scale", validate_args)
+ normal_scale = maybe_check_quadrature_param(
+ normal_scale, "normal_scale", validate_args)
- dist = normal_lib.Normal(loc=loc, scale=scale)
+ dist = normal_lib.Normal(loc=normal_loc, scale=normal_scale)
def _get_batch_ndims():
"""Helper to get dist.batch_shape.ndims, statically if possible."""
@@ -195,114 +208,51 @@ def quadrature_scheme_softmaxnormal_quantiles(
class VectorDiffeomixture(distribution_lib.Distribution):
"""VectorDiffeomixture distribution.
- The VectorDiffeomixture is an approximation to a [compound distribution](
- https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e.,
+ A vector diffeomixture (VDM) is a distribution parameterized by a convex
+ combination of `K` component `loc` vectors, `loc[k], k = 0,...,K-1`, and `K`
+ `scale` matrices `scale[k], k = 0,..., K-1`. It approximates the following
+ [compound distribution]
+ (https://en.wikipedia.org/wiki/Compound_probability_distribution)
```none
- p(x) = int_{X} q(x | v) p(v) dv
- = lim_{Q->infty} sum{ prob[i] q(x | loc=sum_k^K lambda[k;i] loc[k],
- scale=sum_k^K lambda[k;i] scale[k])
- : i=0, ..., Q-1 }
+ p(x) = int p(x | z) p(z) dz,
+ where z is in the K-simplex, and
+ p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
```
- where `q(x | v)` is a vector version of the `distribution` argument and `p(v)`
- is a SoftmaxNormal parameterized by `mix_loc` and `mix_scale`. The
- vector-ization of `distribution` entails an affine transformation of iid
- samples from `distribution`. The `prob` term is from quadrature and
- `lambda[k] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[k])` where the
- `grid` points correspond to the `prob`s.
-
- In the non-approximation case, a draw from the mixture distribution (the
- "prior") represents the convex weights for different affine transformations.
- I.e., draw a mixing vector `v` (from the `K-1`-simplex) and let the final
- sample be: `y = (sum_k^K v[k] scale[k]) @ x + (sum_k^K v[k] loc[k])` where `@`
- denotes matrix multiplication. However, the non-approximate distribution does
- not have an analytical probability density function (pdf). Therefore the
- `VectorDiffeomixture` class implements an approximation based on
- [numerical quadrature](
- https://en.wikipedia.org/wiki/Numerical_integration) (default:
- [Gauss--Hermite quadrature](
- https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in
- Note: although the `VectorDiffeomixture` is approximately the
- `SoftmaxNormal-Distribution` compound distribution, it is itself a valid
- distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which
- are all mutually consistent.
-
- #### Intended Use
-
- This distribution is noteworthy because it implements a mixture of
- `Vector`-ized distributions yet has samples differentiable in the
- distribution's parameters (aka "reparameterized"). It has an analytical
- density function with `O(dKQ)` complexity. `d` is the vector dimensionality,
- `K` is the number of components, and `Q` is the number of quadrature points.
- These properties make it well-suited for Bayesian Variational Inference, i.e.,
- as a surrogate family for the posterior.
-
- For large values of `mix_scale`, the `VectorDistribution` behaves increasingly
- like a discrete mixture. (In most cases this limit is only achievable by also
- increasing the quadrature polynomial degree, `Q`.)
-
- The term `Vector` is consistent with similar named Tensorflow `Distribution`s.
- For more details, see the "About `Vector` distributions in Tensorflow."
- section.
-
- The term `Diffeomixture` is a portmanteau of
- [diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism) and [compound
- mixture](https://en.wikipedia.org/wiki/Compound_probability_distribution). For
- more details, see the "About `Diffeomixture`s and reparametrization.`"
- section.
-
- #### Mathematical Details
-
- The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior")
- [compound distribution](
- https://en.wikipedia.org/wiki/Compound_probability_distribution).
- Using variable-substitution and [numerical quadrature](
- https://en.wikipedia.org/wiki/Numerical_integration) (default:
- [Gauss--Hermite quadrature](
- https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
- redefine the distribution to be a parameter-less convex combination of `K`
- different affine combinations of a `d` iid samples from `distribution`.
-
- That is, defined over `R**d` this distribution is parameterized by a
- (batch of) length-`K` `mix_loc` and `mix_scale` vectors, a length-`K` list of
- (a batch of) length-`d` `loc` vectors, and a length-`K` list of `scale`
- `LinearOperator`s each operating on a (batch of) length-`d` vector space.
- Finally, a `distribution` parameter specifies the underlying base distribution
- which is "lifted" to become multivariate ("lifting" is the same concept as in
- `TransformedDistribution`).
-
- The probability density function (pdf) is,
+ The integral `int p(x | z) p(z) dz` is approximated with a quadrature scheme
+ adapted to the mixture density `p(z)`. The `N` quadrature points `z_{N, n}`
+ and weights `w_{N, n}` (which are non-negative and sum to 1) are chosen
+ such that
- ```none
- pdf(y; mix_loc, mix_scale, loc, scale, phi)
- = sum{ prob[i] phi(f_inverse(x; i)) / abs(det(interp_scale[i]))
- : i=0, ..., Q-1 }
- ```
+ ```q_N(x) := sum_{n=1}^N w_{n, N} p(x | z_{N, n}) --> p(x)```
- where, `phi` is the base distribution pdf, and,
+ as `N --> infinity`.
- ```none
- f_inverse(x; i) = inv(interp_scale[i]) @ (x - interp_loc[i])
- interp_loc[i] = sum{ lambda[k; i] loc[k] : k=0, ..., K-1 }
- interp_scale[i] = sum{ lambda[k; i] scale[k] : k=0, ..., K-1 }
- ```
+ Since `q_N(x)` is in fact a mixture (of `N` points), we may sample from
+ `q_N` exactly. It is important to note that the VDM is *defined* as `q_N`
+ above, and *not* `p(x)`. Therefore, sampling and pdf may be implemented as
+ exact (up to floating point error) methods.
- and,
+ A common choice for the conditional `p(x | z)` is a multivariate Normal.
- ```none
- grid, weight = np.polynomial.hermite.hermgauss(quadrature_size)
- prob[k] = weight[k] / sqrt(pi)
- lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
+ The implemented marginal `p(z)` is the `SoftmaxNormal`, which is a
+ `K-1` dimensional Normal transformed by a `SoftmaxCentered` bijector, making
+ it a density on the `K`-simplex. That is,
+
+ ```
+ Z = SoftmaxCentered(X),
+ X = Normal(mix_loc / temperature, 1 / temperature)
```
- The distribution corresponding to `phi` must be a scalar-batch, scalar-event
- distribution. Typically it is reparameterized. If not, it must be a function
- of non-trainable parameters.
+ The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of
+ the quantiles of `p(z)` (generalized quantiles if `K > 2`).
- WARNING: If you backprop through a VectorDiffeomixture sample and the "base"
- distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable
- variables, then the gradient is not guaranteed correct!
+ See [1] for more details.
+
+ [1]. "Quadrature Compound: An approximating family of distributions"
+ Joshua Dillon, Ian Langmore, arXiv preprints
+ https://arxiv.org/abs/1801.03080
#### About `Vector` distributions in TensorFlow.
@@ -310,12 +260,11 @@ class VectorDiffeomixture(distribution_lib.Distribution):
particularly useful in [variational Bayesian
methods](https://en.wikipedia.org/wiki/Variational_Bayesian_methods).
- Conditioned on a draw from the SoftmaxNormal, `Y|v` is a vector whose
+ Conditioned on a draw from the SoftmaxNormal, `X|z` is a vector whose
components are linear combinations of affine transformations, thus is itself
- an affine transformation. Therefore `Y|v` lives in the vector space generated
- by vectors of affine-transformed distributions.
+ an affine transformation.
- Note: The marginals `Y_1|v, ..., Y_d|v` are *not* generally identical to some
+ Note: The marginals `X_1|v, ..., X_d|v` are *not* generally identical to some
parameterization of `distribution`. This is due to the fact that the sum of
draws from `distribution` are not generally itself the same `distribution`.
@@ -331,12 +280,16 @@ class VectorDiffeomixture(distribution_lib.Distribution):
optimize Monte-Carlo objectives. Such objectives are a finite-sample
approximation of an expectation and arise throughout scientific computing.
+ WARNING: If you backprop through a VectorDiffeomixture sample and the "base"
+ distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable
+ variables, then the gradient is not guaranteed correct!
+
#### Examples
```python
tfd = tf.contrib.distributions
- # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and
+ # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.],
# another with mix_loc=[1]. In both cases, `K=2` and the affine
# transformations involve:
# k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity
@@ -344,7 +297,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
dims = 5
vdm = tfd.VectorDiffeomixture(
mix_loc=[[0.], [1]],
- mix_scale=[1.],
+ temperature=[1.],
distribution=tfd.Normal(loc=0., scale=1.),
loc=[
None, # Equivalent to `np.zeros(dims, dtype=np.float32)`.
@@ -364,7 +317,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
def __init__(self,
mix_loc,
- mix_scale,
+ temperature,
distribution,
loc=None,
scale=None,
@@ -373,15 +326,24 @@ class VectorDiffeomixture(distribution_lib.Distribution):
validate_args=False,
allow_nan_stats=True,
name="VectorDiffeomixture"):
- """Constructs the VectorDiffeomixture on `R**d`.
+ """Constructs the VectorDiffeomixture on `R^d`.
+
+ The vector diffeomixture (VDM) approximates the compound distribution
+
+ ```none
+ p(x) = int p(x | z) p(z) dz,
+ where z is in the K-simplex, and
+ p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
+ ```
Args:
- mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. Represents
- the `location` parameter of the SoftmaxNormal used for selecting one of
- the `K` affine transformations.
- mix_scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
- Represents the `scale` parameter of the SoftmaxNormal used for selecting
- one of the `K` affine transformations.
+ mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
+ In terms of samples, larger `mix_loc[..., k]` ==>
+ `Z` is more likely to put more weight on its `kth` component.
+ temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
+ In terms of samples, smaller `temperature` means one component is more
+ likely to dominate. I.e., smaller `temperature` makes the VDM look more
+ like a standard mixture of `K` components.
distribution: `tf.Distribution`-like instance. Distribution from which `d`
iid samples are used as input to the selected affine transformation.
Must be a scalar-batch, scalar-event distribution. Typically
@@ -401,8 +363,9 @@ class VectorDiffeomixture(distribution_lib.Distribution):
transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
`b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
quadrature_size: Python `int` scalar representing number of
- quadrature points.
- quadrature_fn: Python callable taking `mix_loc`, `mix_scale`,
+ quadrature points. Larger `quadrature_size` means `q_N(x)` better
+ approximates `p(x)`.
+ quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
`quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
representing the SoftmaxNormal grid and corresponding normalized weight.
normalized) weight.
@@ -430,7 +393,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
ValueError: if `not distribution.is_scalar_event`.
"""
parameters = locals()
- with ops.name_scope(name, values=[mix_loc, mix_scale]):
+ with ops.name_scope(name, values=[mix_loc, temperature]):
if not scale or len(scale) < 2:
raise ValueError("Must specify list (or list-like object) of scale "
"LinearOperators, one for each component with "
@@ -473,8 +436,15 @@ class VectorDiffeomixture(distribution_lib.Distribution):
raise NotImplementedError("Currently only bimixtures are supported; "
"len(scale)={} is not 2.".format(len(scale)))
+ mix_loc = ops.convert_to_tensor(
+ mix_loc, dtype=dtype, name="mix_loc")
+ temperature = ops.convert_to_tensor(
+ temperature, dtype=dtype, name="temperature")
self._grid, probs = tuple(quadrature_fn(
- mix_loc, mix_scale, quadrature_size, validate_args))
+ mix_loc / temperature,
+ 1. / temperature,
+ quadrature_size,
+ validate_args))
# Note: by creating the logits as `log(prob)` we ensure that
# `self.mixture_distribution.logits` is equivalent to
@@ -618,7 +588,14 @@ class VectorDiffeomixture(distribution_lib.Distribution):
weight = array_ops.gather(
array_ops.reshape(self.grid, shape=[-1]),
ids + offset)
- weight = weight[..., array_ops.newaxis]
+ # At this point, weight flattened all batch dims into one.
+ # We also need to append a singleton to broadcast with event dims.
+ if self.batch_shape.is_fully_defined():
+ new_shape = [-1] + self.batch_shape.as_list() + [1]
+ else:
+ new_shape = array_ops.concat(
+ ([-1], self.batch_shape_tensor(), [1]), axis=0)
+ weight = array_ops.reshape(weight, shape=new_shape)
if len(x) != 2:
# We actually should have already triggered this exception. However as a
@@ -686,7 +663,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
# To compute E[Cov(Z|V)], we'll add matrices within three categories:
# scaled-identity, diagonal, and full. Then we'll combine these at the end.
- scaled_identity = None
+ scale_identity_multiplier = None
diag = None
full = None
@@ -694,10 +671,12 @@ class VectorDiffeomixture(distribution_lib.Distribution):
s = aff.scale # Just in case aff.scale has side-effects, we'll call once.
if (s is None
or isinstance(s, linop_identity_lib.LinearOperatorIdentity)):
- scaled_identity = add(scaled_identity, p[..., k, array_ops.newaxis])
+ scale_identity_multiplier = add(scale_identity_multiplier,
+ p[..., k, array_ops.newaxis])
elif isinstance(s, linop_identity_lib.LinearOperatorScaledIdentity):
- scaled_identity = add(scaled_identity, (p[..., k, array_ops.newaxis] *
- math_ops.square(s.multiplier)))
+ scale_identity_multiplier = add(
+ scale_identity_multiplier,
+ (p[..., k, array_ops.newaxis] * math_ops.square(s.multiplier)))
elif isinstance(s, linop_diag_lib.LinearOperatorDiag):
diag = add(diag, (p[..., k, array_ops.newaxis] *
math_ops.square(s.diag_part())))
@@ -709,12 +688,13 @@ class VectorDiffeomixture(distribution_lib.Distribution):
full = add(full, x)
# We must now account for the fact that the base distribution might have a
- # non-unity variance. Recall that `Cov(SX+m) = S.T Cov(X) S = S.T S Var(X)`.
+ # non-unity variance. Recall that, since X ~ iid Law(X_0),
+ # `Cov(SX+m) = S Cov(X) S.T = S S.T Diag(Var(X_0))`.
# We can scale by `Var(X)` (vs `Cov(X)`) since X corresponds to `d` iid
# samples from a scalar-event distribution.
v = self.distribution.variance()
- if scaled_identity is not None:
- scaled_identity *= v
+ if scale_identity_multiplier is not None:
+ scale_identity_multiplier *= v
if diag is not None:
diag *= v[..., array_ops.newaxis]
if full is not None:
@@ -723,10 +703,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
if diag_only:
# Apparently we don't need the full matrix, just the diagonal.
r = add(diag, full)
- if r is None and scaled_identity is not None:
+ if r is None and scale_identity_multiplier is not None:
ones = array_ops.ones(self.event_shape_tensor(), dtype=self.dtype)
- return scaled_identity * ones
- return add(r, scaled_identity)
+ return scale_identity_multiplier[..., array_ops.newaxis] * ones
+ return add(r, scale_identity_multiplier)
# `None` indicates we don't know if the result is positive-definite.
is_positive_definite = (True if all(aff.scale.is_positive_definite
@@ -742,10 +722,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
to_add.append(linop_full_lib.LinearOperatorFullMatrix(
matrix=full,
is_positive_definite=is_positive_definite))
- if scaled_identity is not None:
+ if scale_identity_multiplier is not None:
to_add.append(linop_identity_lib.LinearOperatorScaledIdentity(
num_rows=self.event_shape_tensor()[0],
- multiplier=scaled_identity,
+ multiplier=scale_identity_multiplier,
is_positive_definite=is_positive_definite))
return (linop_add_lib.add_operators(to_add)[0].to_dense()
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index fa520c174a..e984c63af7 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -69,6 +69,7 @@ cuda_py_test(
srcs = ["datasets_test.py"],
additional_deps = [
":datasets",
+ "//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py
index f820990bbe..ff419614f5 100644
--- a/tensorflow/contrib/eager/python/checkpointable_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_test.py
@@ -70,42 +70,36 @@ class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable):
checkpointable.Checkpointable.__init__(self)
adam.AdamOptimizer.__init__(self, *args, **kwargs)
- # NOTE: Copied from AdamOptimizer with modifications to use add_variable
+ # NOTE: Copied from Optimizer with modifications to use add_variable
# for non-slot variables. These contortions are necessary to maintain
# checkpoint compatibility with variable.name based saving.
- def _create_slots(self, var_list):
- # Create the beta1 and beta2 accumulators on the same device as the first
- # variable. Sort the var_list to make sure this device is consistent across
- # workers (these need to go on the same PS, otherwise some updates are
- # silently ignored).
- first_var = min(var_list, key=lambda x: x.name)
-
- create_new = self._beta1_power is None
- if not create_new and context.in_graph_mode():
- create_new = (self._beta1_power.graph is not first_var.graph)
-
- if create_new:
- with ops.colocate_with(first_var):
+ # TODO(allenl): Make this cleaner.
+ def _create_non_slot_variable(self, initial_value, name, colocate_with):
+ """Add an extra variable, not associated with a slot."""
+ if context.in_graph_mode():
+ graph = colocate_with.graph
+ else:
+ graph = None
+ key = (name, graph)
+ v = self._non_slot_dict.get(key, None)
+ if v is None:
+ with ops.colocate_with(colocate_with):
def _variable_getter(name, shape, dtype, initializer):
del shape, dtype # not used, but there for compatibility
return variable_scope.variable(
name=name, initial_value=initializer, trainable=False)
- self._beta1_power = self.add_variable(
- name="beta1_power",
- shape=[],
- initializer=self._beta1,
+ initial_value = ops.convert_to_tensor(initial_value)
+ v = self.add_variable(
+ name=name,
+ shape=initial_value.get_shape(),
+ initializer=initial_value,
getter=_variable_getter)
- self._beta2_power = self.add_variable(
- name="beta2_power",
- shape=[],
- initializer=self._beta2,
- getter=_variable_getter)
- # Create slots for the first and second moments.
- for v in var_list:
- self._zeros_slot(v, "m", self._name)
- self._zeros_slot(v, "v", self._name)
+
+ self._non_slot_dict[key] = v
+
+ return v
# TODO(allenl): Override slot variable creation (_get_or_make_slot,
# _get_or_make_slot_with_initializer, _zeros_slot) to allow deferred
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 7ca93d37f2..544a3eafc0 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -42,7 +42,7 @@ def _generate_shared_name(prefix):
global _uid_counter
uid = _uid_counter
_uid_counter += 1
- return "{}_{}".format(prefix, uid)
+ return "{}{}".format(prefix, uid)
class Iterator(object):
@@ -84,8 +84,8 @@ class Iterator(object):
self._flat_output_shapes = nest.flatten(
sparse.as_dense_shapes(self._output_shapes, self._output_classes))
self._resource = gen_dataset_ops.iterator(
- container="",
- shared_name=_generate_shared_name("eager_iterator"),
+ shared_name="",
+ container=_generate_shared_name("eageriterator"),
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
gen_dataset_ops.make_iterator(ds_variant, self._resource)
@@ -141,7 +141,12 @@ class Iterator(object):
# TODO(ashankar): Consider removing this ops.device() contextmanager
# and instead mimic ops placement in graphs: Operations on resource
# handles execute on the same device as where the resource is placed.
- ret = gen_dataset_ops.iterator_get_next(
+ # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
+ # because in eager mode this code will run synchronously on the calling
+ # thread. Therefore we do not need to make a defensive context switch
+ # to a background thread, and can achieve a small constant performance
+ # boost by invoking the iterator synchronously.
+ ret = gen_dataset_ops.iterator_get_next_sync(
self._resource,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index 7e68c1a0b6..a1611e92b1 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -20,9 +20,11 @@ import time
import numpy as np
+from tensorflow.contrib import lookup
from tensorflow.contrib.eager.python import datasets
from tensorflow.python.data import Dataset
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -79,6 +81,18 @@ class IteratorTest(test.TestCase):
got = [x.numpy() for x in it]
self.assertAllEqual([0, 4, 16, 36], got)
+ def testMapCaptureLookupTable(self):
+ default_val = -1
+ keys = constant_op.constant(['brain', 'salad', 'surgery'])
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
+ table = lookup.HashTable(
+ lookup.KeyValueTensorInitializer(keys, values), default_val)
+ dataset = Dataset.from_tensor_slices(['brain', 'salad', 'surgery'])
+ dataset = dataset.map(table.lookup)
+ it = datasets.Iterator(dataset)
+ got = [x.numpy() for x in it]
+ self.assertAllEqual([0, 1, 2], got)
+
def testMultipleIteratorsOnADatasetThatUsesFunctions(self):
ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square)
diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py
index 3faaeef590..68e7b5421f 100644
--- a/tensorflow/contrib/eager/python/evaluator.py
+++ b/tensorflow/contrib/eager/python/evaluator.py
@@ -178,7 +178,7 @@ class Evaluator(object):
call_op: An op that updates evaluation state on a mini-batch of examples.
Must generate an tf.errors.OutOfRangeError when done.
results_op: A dictionary of tensors that compute the final evaluation
- results from the evaulation state.
+ results from the evaluation state.
sess: The Session to run the evaluation in. Defaults to the default
Session.
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
index bab7ad0c70..f86331af6f 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
@@ -23,3 +23,13 @@ cuda_py_test(
"//tensorflow:tensorflow_py",
],
)
+
+cuda_py_test(
+ name = "linear_regression_graph_test",
+ size = "small",
+ srcs = ["linear_regression_graph_test.py"],
+ additional_deps = [
+ ":linear_regression",
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
index f4b7d67f94..6ce4de6ee0 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
@@ -63,6 +63,10 @@ class LinearModel(tfe.Network):
return self._hidden_layer(xs)
+def mean_square_loss(model, xs, ys):
+ return tf.reduce_mean(tf.square(model(xs) - ys))
+
+
def fit(model, dataset, optimizer, verbose=False, logdir=None):
"""Fit the linear-regression model.
@@ -76,10 +80,8 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None):
"""
# The loss function to optimize.
- def mean_square_loss(xs, ys):
- return tf.reduce_mean(tf.square(model(xs) - ys))
-
- loss_and_grads = tfe.implicit_value_and_gradients(mean_square_loss)
+ mse = lambda xs, ys: mean_square_loss(model, xs, ys)
+ loss_and_grads = tfe.implicit_value_and_gradients(mse)
tf.train.get_or_create_global_step()
if logdir:
@@ -103,14 +105,20 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None):
def synthetic_dataset(w, b, noise_level, batch_size, num_batches):
"""tf.data.Dataset that yields synthetic data for linear regression."""
+ return synthetic_dataset_helper(w, b,
+ tf.shape(w)[0], noise_level, batch_size,
+ num_batches)
+
+def synthetic_dataset_helper(w, b, num_features, noise_level, batch_size,
+ num_batches):
# w is a matrix with shape [N, M]
# b is a vector with shape [M]
# So:
# - Generate x's as vectors with shape [batch_size N]
# - y = tf.matmul(x, W) + b + noise
def batch(_):
- x = tf.random_normal([batch_size, tf.shape(w)[0]])
+ x = tf.random_normal([batch_size, num_features])
y = tf.matmul(x, w) + b + noise_level * tf.random_normal([])
return x, y
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py
new file mode 100644
index 0000000000..557ad42752
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py
@@ -0,0 +1,85 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Graph benchmark for linear regression, to contrast with eager execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.linear_regression import linear_regression
+
+
+class GraphLinearRegressionBenchmark(tf.test.Benchmark):
+
+ def benchmarkGraphLinearRegression(self):
+ num_epochs = 10
+ num_batches = 200
+ batch_size = 64
+ dataset = linear_regression.synthetic_dataset_helper(
+ w=tf.random_uniform([3, 1]),
+ b=tf.random_uniform([1]),
+ num_features=3,
+ noise_level=0.01,
+ batch_size=batch_size,
+ num_batches=num_batches)
+ iterator = dataset.make_initializable_iterator()
+ x, y = iterator.get_next()
+
+ model = linear_regression.LinearModel()
+
+ if tf.test.is_gpu_available():
+ use_gpu = True
+ device = "/device:GPU:0"
+ else:
+ use_gpu = False
+ device = "/device:CPU:0"
+
+ with tf.device(device):
+ loss = linear_regression.mean_square_loss(model, x, y)
+ optimization_step = tf.train.GradientDescentOptimizer(
+ learning_rate=0.1).minimize(loss)
+
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+
+ def train(num_epochs):
+ for _ in range(num_epochs):
+ sess.run(iterator.initializer)
+ try:
+ while True:
+ _, _ = sess.run([optimization_step, loss])
+ except tf.errors.OutOfRangeError:
+ pass
+
+ # Warmup: a single epoch.
+ train(1)
+
+ start_time = time.time()
+ train(num_epochs)
+ wall_time = time.time() - start_time
+
+ examples_per_sec = num_epochs * num_batches * batch_size / wall_time
+ self.report_benchmark(
+ name="graph_train_%s" %
+ ("gpu" if use_gpu else "cpu"),
+ iters=num_epochs * num_batches,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py
index 39e7aabd7b..e53234b51a 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py
@@ -83,6 +83,7 @@ class LinearRegressionTest(tf.test.TestCase):
class EagerLinearRegressionBenchmark(tf.test.Benchmark):
def benchmarkEagerLinearRegression(self):
+ num_epochs = 10
num_batches = 200
batch_size = 64
dataset = linear_regression.synthetic_dataset(
@@ -102,14 +103,15 @@ class EagerLinearRegressionBenchmark(tf.test.Benchmark):
linear_regression.fit(model, burn_in_dataset, optimizer)
start_time = time.time()
- linear_regression.fit(model, dataset, optimizer)
+ for _ in range(num_epochs):
+ linear_regression.fit(model, dataset, optimizer)
wall_time = time.time() - start_time
- examples_per_sec = num_batches * batch_size / wall_time
+ examples_per_sec = num_epochs * num_batches * batch_size / wall_time
self.report_benchmark(
name="eager_train_%s" %
("gpu" if tfe.num_gpus() > 0 else "cpu"),
- iters=num_batches,
+ iters=num_epochs * num_batches,
extras={"examples_per_sec": examples_per_sec},
wall_time=wall_time)
diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py
index 82b3d3919c..2a7be95811 100644
--- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py
+++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py
@@ -23,7 +23,6 @@ from __future__ import division
from __future__ import print_function
import argparse
-import functools
import os
import sys
import time
@@ -124,21 +123,18 @@ def train_one_epoch(model, optimizer, dataset, log_interval=None):
tf.train.get_or_create_global_step()
- def model_loss(labels, images):
- prediction = model(images, training=True)
- loss_value = loss(prediction, labels)
- tf.contrib.summary.scalar('loss', loss_value)
- tf.contrib.summary.scalar('accuracy',
- compute_accuracy(prediction, labels))
- return loss_value
-
for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
with tf.contrib.summary.record_summaries_every_n_global_steps(10):
- batch_model_loss = functools.partial(model_loss, labels, images)
- optimizer.minimize(
- batch_model_loss, global_step=tf.train.get_global_step())
+ with tfe.GradientTape() as tape:
+ prediction = model(images, training=True)
+ loss_value = loss(prediction, labels)
+ tf.contrib.summary.scalar('loss', loss_value)
+ tf.contrib.summary.scalar('accuracy',
+ compute_accuracy(prediction, labels))
+ grads = tape.gradient(loss_value, model.variables)
+ optimizer.apply_gradients(zip(grads, model.variables))
if log_interval and batch % log_interval == 0:
- print('Batch #%d\tLoss: %.6f' % (batch, batch_model_loss()))
+ print('Batch #%d\tLoss: %.6f' % (batch, loss_value))
def test(model, dataset):
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/README.md b/tensorflow/contrib/eager/python/examples/resnet50/README.md
index db023e6c97..79e4600529 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/README.md
+++ b/tensorflow/contrib/eager/python/examples/resnet50/README.md
@@ -34,7 +34,7 @@ bazel run -c opt --config=cuda :resnet50_graph_test -- --benchmarks=.
(Or remove the `--config=cuda` flag for running on CPU instead of GPU).
-On October 31, 2017, the benchmarks demostrated comparable performance
+On October 31, 2017, the benchmarks demonstrated comparable performance
for eager and graph execution of this particular model when using
a single NVIDIA Titan X (Pascal) GPU on a host with an
Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32.
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
index b302a87e0e..9982fdb07e 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
@@ -97,7 +97,7 @@ class _ConvBlock(tfe.Network):
Args:
kernel_size: the kernel size of middle conv layer at main path
- filters: list of integers, the filterss of 3 conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
data_format: data_format for the input ('channels_first' or
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index e2ae665a74..76e06269b6 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -52,14 +52,13 @@ def random_batch(batch_size):
def train_one_step(model, images, labels, optimizer):
- def model_loss():
+ with tfe.GradientTape() as tape:
logits = model(images, training=True)
loss = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=labels)
tf.contrib.summary.scalar(name='loss', tensor=loss)
- return loss
-
- optimizer.minimize(model_loss)
+ grads = tape.gradient(loss, model.variables)
+ optimizer.apply_gradients(zip(grads, model.variables))
class ResNet50Test(tf.test.TestCase):
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md
index 743ebb68ee..966177e91c 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md
@@ -40,7 +40,7 @@ bazel run -c opt --config=cuda :rnn_ptb_graph_test -- --benchmarks=.
(Or remove the `--config=cuda` flag for running on CPU instead of GPU).
-On October 31, 2017, the benchmarks demostrated slightly better performance
+On October 31, 2017, the benchmarks demonstrated slightly better performance
(3-6%) for graph execution over eager execution for this particular model when
using a single NVIDIA Titan X (Pascal) GPU on a host with an Intel Xeon E5-1650
CPU @ 3.50GHz and a batch size of 32.
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index 7b9637a9d5..d34e9ea68b 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -88,7 +88,7 @@ class Embedding(tf.layers.Layer):
class PTBModel(tfe.Network):
- """LSTM for word language modelling.
+ """LSTM for word language modeling.
Model described in:
(Zaremba, et. al.) Recurrent Neural Network Regularization
@@ -340,7 +340,7 @@ if __name__ == "__main__":
parser.add_argument(
"--logdir", type=str, default="", help="Directory for checkpoint.")
parser.add_argument(
- "--epoch", type=int, default=20, help="Number of epoches.")
+ "--epoch", type=int, default=20, help="Number of epochs.")
parser.add_argument("--batch-size", type=int, default=20, help="Batch size.")
parser.add_argument(
"--seq-len", type=int, default=35, help="Sequence length.")
diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py
index a6e046320f..fcaae0a4f8 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/data.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/data.py
@@ -51,11 +51,11 @@ def get_non_parenthesis_words(items):
"""Get the non-parenthesis items from a SNLI parsed sentence.
Args:
- items: Data items from a parsed SNLI setence, with parentheses. E.g.,
+ items: Data items from a parsed SNLI sentence, with parentheses. E.g.,
["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ...
Returns:
- A list of non-parenthis word items, all converted to lower case. E.g.,
+ A list of non-parentheses word items, all converted to lower case. E.g.,
["man", "wearing", "pass", ...
"""
return [x.lower() for x in items if x not in PARENTHESES and x]
@@ -201,7 +201,7 @@ def load_word_vectors(data_root, vocab):
def calculate_bins(length2count, min_bin_size):
- """Cacluate bin boundaries given a histogram of lengths and mininum bin size.
+ """Calculate bin boundaries given a histogram of lengths and minimum bin size.
Args:
length2count: A `dict` mapping length to sentence count.
@@ -335,9 +335,9 @@ class SnliData(object):
# The sorting above and the batching here makes sure that sentences of
# similar max lengths are batched together, minimizing the inefficiency
# due to uneven max lengths. The sentences are batched differently in
- # each call to get_generator() due to the shuffling before sotring
+ # each call to get_generator() due to the shuffling before sorting
# above. The pad_and_reverse_word_ids() and pad_transitions() functions
- # take care of any remaning unevenness of the max sentence lengths.
+ # take care of any remaining unevenness of the max sentence lengths.
end = min(begin + batch_size, len(labels))
# Transpose, because the SPINN model requires time-major, instead of
# batch-major.
diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md
index 0095ffa0db..7eea93ce1f 100644
--- a/tensorflow/contrib/eager/python/g3doc/guide.md
+++ b/tensorflow/contrib/eager/python/g3doc/guide.md
@@ -292,7 +292,7 @@ def loss(weight, bias):
error = prediction(training_inputs, weight, bias) - training_outputs
return tf.reduce_mean(tf.square(error))
-# Function that returns the the derivative of loss with respect to
+# Function that returns the derivative of loss with respect to
# weight and bias
grad = tfe.gradients_function(loss)
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index 8e6b947e5c..81c77e41ac 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -688,7 +688,7 @@ class NetworkTest(test.TestCase):
net2(one)
# Layer names typically are globally unique rather than being unique within
# the scope of their first use. However, within a Network they must be named
- # locally so that previous Layer consutrciton does not interfere with
+ # locally so that previous Layer construction does not interfere with
# variable naming (e.g. add a Layer construction before the Network,
# suddenly your previously saved checkpoint is incompatible).
self.assertEqual("dense", net1.l1.name)
diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py
index 57b070ec6e..62421849c7 100644
--- a/tensorflow/contrib/eager/python/saver.py
+++ b/tensorflow/contrib/eager/python/saver.py
@@ -82,7 +82,7 @@ def restore_variables_on_create(save_path, map_func=None):
map_func_wrapper = lambda self, x: x
else:
if not callable(map_func):
- raise ValueError("map_func must be callaled.")
+ raise ValueError("map_func must be callable.")
map_func_wrapper = lambda self, x: map_func(x)
ckpt_var_cache = dict()
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index cdbe05e4d2..6cdbed5b89 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -163,7 +163,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:check_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lookup_ops",
@@ -177,7 +177,6 @@ py_library(
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/estimator:util",
"//tensorflow/python/ops/losses",
"//tensorflow/python/saved_model:signature_constants",
],
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index d6ca33e189..238cf287b7 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
@@ -29,7 +28,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -45,6 +43,7 @@ def multi_class_head(n_classes,
weight_column=None,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""Creates a `_Head` for multi class classification.
@@ -65,6 +64,12 @@ def multi_class_head(n_classes,
labels have shape `[batch_size, 1]`, the loss is the weighted sum over
`batch_size`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with
+ shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
+ the input labels before passing them to `loss_fn`.
+
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`binary_classification_head`).
@@ -79,6 +84,7 @@ def multi_class_head(n_classes,
`label_vocabulary` is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -94,12 +100,17 @@ def multi_class_head(n_classes,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
def binary_classification_head(
- weight_column=None, thresholds=None, label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM, name=None):
+ weight_column=None,
+ thresholds=None,
+ label_vocabulary=None,
+ loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
+ name=None):
"""Creates a `_Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
@@ -119,6 +130,12 @@ def binary_classification_head(
labels have shape `[batch_size, 1]`, the loss is the weighted sum over
`batch_size`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with
+ shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
+ the input labels before passing them to `loss_fn`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -136,6 +153,7 @@ def binary_classification_head(
is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -151,12 +169,14 @@ def binary_classification_head(
thresholds=thresholds,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
def regression_head(weight_column=None,
label_dimension=1,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""Creates a `_Head` for regression using the `mean_squared_error` loss.
@@ -175,6 +195,10 @@ def regression_head(weight_column=None,
`[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
`[D0, D1, ... DN, label_dimension]`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, label_dimension]`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -185,6 +209,7 @@ def regression_head(weight_column=None,
`[batch_size, label_dimension]`).
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -198,6 +223,7 @@ def regression_head(weight_column=None,
weight_column=weight_column,
label_dimension=label_dimension,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
@@ -220,7 +246,7 @@ def multi_label_head(n_classes,
`batch_size`.
The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
- applications, the shape is `[batch_size, label_n_classes]`.
+ applications, the shape is `[batch_size, n_classes]`.
Labels can be:
* A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
@@ -287,7 +313,7 @@ def multi_label_head(n_classes,
'Length of label_vocabulary must be n_classes ({}). '
'Given: {}'.format(n_classes, len(label_vocabulary)))
if loss_fn:
- _validate_loss_fn_args(loss_fn)
+ head_lib._validate_loss_fn_args(loss_fn) # pylint:disable=protected-access
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
@@ -371,9 +397,9 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
labels=processed_labels, logits=logits,
expected_labels_dimension=self.logits_dimension)
if self._loss_fn:
- unweighted_loss = _call_loss_fn(
+ unweighted_loss = head_lib._call_loss_fn( # pylint:disable=protected-access
loss_fn=self._loss_fn, labels=processed_labels, logits=logits,
- features=features)
+ features=features, expected_loss_dim=1)
else:
unweighted_loss = losses.sigmoid_cross_entropy(
multi_class_labels=processed_labels, logits=logits,
@@ -392,8 +418,32 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
processed_labels=processed_labels)
def create_estimator_spec(
- self, features, mode, logits, labels=None, train_op_fn=None):
- """See `Head`."""
+ self, features, mode, logits, labels=None, train_op_fn=None,
+ regularization_losses=None):
+ """Returns an `EstimatorSpec`.
+
+ Args:
+ features: Input `dict` of `Tensor` or `SparseTensor` objects.
+ mode: Estimator's `ModeKeys`.
+ logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`.
+ For many applications, the shape is `[batch_size, n_classes]`.
+ labels: Labels with shape matching `logits`. Can be multi-hot `Tensor`
+ with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with
+ `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when
+ `mode` equals `TRAIN` or `EVAL`.
+ train_op_fn: Function that takes a scalar loss `Tensor` and returns
+ `train_op`. Required in TRAIN mode.
+ regularization_losses: A list of additional scalar losses to be added to
+ the training loss, such as regularization losses. These losses are
+ usually expressed as a batch average, so for best results users need to
+ set `loss_reduction=SUM_OVER_BATCH_SIZE` or
+ `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
+ avoid scaling errors.
+ Returns:
+ `EstimatorSpec`.
+ Raises:
+ ValueError: If `train_op_fn` is `None` in TRAIN mode.
+ """
with ops.name_scope(self._name, 'head'):
logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access
@@ -422,18 +472,26 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
(training_loss, unreduced_loss, weights,
processed_labels) = self.create_loss(
features=features, mode=mode, logits=logits, labels=labels)
+ if regularization_losses:
+ regularization_loss = math_ops.add_n(regularization_losses)
+ regularized_training_loss = math_ops.add_n(
+ [training_loss, regularization_loss])
+ else:
+ regularization_loss = None
+ regularized_training_loss = training_loss
# Eval.
if mode == model_fn.ModeKeys.EVAL:
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
- loss=training_loss,
+ loss=regularized_training_loss,
eval_metric_ops=self._eval_metric_ops(
labels=processed_labels,
probabilities=probabilities,
weights=weights,
- unreduced_loss=unreduced_loss))
+ unreduced_loss=unreduced_loss,
+ regularization_loss=regularization_loss))
# Train.
if train_op_fn is None:
@@ -447,25 +505,31 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
else:
mean_loss = None
with ops.name_scope(''):
+ keys = metric_keys.MetricKeys
summary.scalar(
- head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), # pylint:disable=protected-access
- training_loss)
+ head_lib._summary_key(self._name, keys.LOSS), # pylint:disable=protected-access
+ regularized_training_loss)
if mean_loss is not None:
summary.scalar(
- head_lib._summary_key( # pylint:disable=protected-access
- self._name, metric_keys.MetricKeys.LOSS_MEAN),
+ head_lib._summary_key(self._name, keys.LOSS_MEAN), # pylint:disable=protected-access
mean_loss)
+ if regularization_loss is not None:
+ summary.scalar(
+ head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access
+ regularization_loss)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
- loss=training_loss,
- train_op=train_op_fn(training_loss))
+ loss=regularized_training_loss,
+ train_op=train_op_fn(regularized_training_loss))
- def _eval_metric_ops(self, labels, probabilities, weights, unreduced_loss):
+ def _eval_metric_ops(
+ self, labels, probabilities, weights, unreduced_loss,
+ regularization_loss):
"""Returns a dict of metrics for eval_metric_ops."""
with ops.name_scope(
None, 'metrics',
- [labels, probabilities, weights, unreduced_loss]):
+ [labels, probabilities, weights, unreduced_loss, regularization_loss]):
keys = metric_keys.MetricKeys
metric_ops = {
# Estimator already adds a metric for loss.
@@ -482,6 +546,13 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
weights=weights, curve='PR',
name=keys.AUC_PR),
}
+ if regularization_loss is not None:
+ loss_regularization_key = head_lib._summary_key( # pylint:disable=protected-access
+ self._name, keys.LOSS_REGULARIZATION)
+ metric_ops[loss_regularization_key] = (
+ metrics_lib.mean(
+ values=regularization_loss,
+ name=keys.LOSS_REGULARIZATION))
for threshold in self._thresholds:
accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
metric_ops[head_lib._summary_key(self._name, accuracy_key)] = ( # pylint:disable=protected-access
@@ -510,52 +581,3 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
threshold=threshold,
name=recall_key))
return metric_ops
-
-
-def _validate_loss_fn_args(loss_fn):
- """Validates loss_fn arguments.
-
- Required arguments: labels, logits.
- Optional arguments: features.
-
- Args:
- loss_fn: The loss function.
- Raises:
- ValueError: If the signature is unexpected.
- """
- loss_fn_args = util.fn_args(loss_fn)
- for required_arg in ['labels', 'logits']:
- if required_arg not in loss_fn_args:
- raise ValueError(
- 'loss_fn must contain argument: {}. '
- 'Given arguments: {}'.format(required_arg, loss_fn_args))
- invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features']))
- if invalid_args:
- raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))
-
-
-def _call_loss_fn(loss_fn, labels, logits, features):
- """Calls loss_fn and checks the returned shape.
-
- Args:
- loss_fn: The loss function.
- labels: Processed labels Tensor.
- logits: Logits Tensor of shape [batch_size, logits_dimension].
- features: Features dict.
- Returns:
- Loss Tensor with shape [batch_size, 1].
- """
- loss_fn_args = util.fn_args(loss_fn)
- kwargs = {}
- if 'features' in loss_fn_args:
- kwargs['features'] = features
- unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)
- batch_size = array_ops.shape(logits)[0]
- loss_shape = array_ops.shape(unweighted_loss)
- check_shape_op = control_flow_ops.Assert(
- math_ops.reduce_all(math_ops.equal(loss_shape, [batch_size, 1])),
- data=[
- 'loss_fn must return Tensor of shape [batch_size, 1]. Given: ',
- loss_shape])
- with ops.control_dependencies([check_shape_op]):
- return array_ops.identity(unweighted_loss)
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index e39e44541d..43cdfec968 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -381,8 +381,8 @@ class MultiLabelHead(test.TestCase):
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- r'loss_fn must return Tensor of shape \[batch_size, 1\]\. '
- r'Given: \] \[2\]'):
+ r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] '
+ r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2\]'):
actual_training_loss.eval()
def test_eval_labels_none(self):
@@ -399,12 +399,13 @@ class MultiLabelHead(test.TestCase):
def _test_eval(
self, head, logits, labels, expected_loss, expected_metrics,
- features=None):
+ features=None, regularization_losses=None):
spec = head.create_estimator_spec(
features=features or {},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
- labels=labels)
+ labels=labels,
+ regularization_losses=regularization_losses)
# Assert spec contains expected tensors.
self.assertIsNotNone(spec.loss)
@@ -486,6 +487,38 @@ class MultiLabelHead(test.TestCase):
expected_loss=expected_loss,
expected_metrics=expected_metrics)
+ def test_eval_with_regularization_losses(self):
+ n_classes = 2
+ head = head_lib.multi_label_head(
+ n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ regularization_losses = [1.5, 0.5]
+ expected_regularization_loss = 2.
+ # unregularized_loss = sum(
+ # labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))) / batch_size
+ expected_unregularized_loss = np.sum(
+ _sigmoid_cross_entropy(labels=labels, logits=logits)) / 2.
+ expected_regularized_loss = (
+ expected_unregularized_loss + expected_regularization_loss)
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ keys.LOSS_MEAN: expected_unregularized_loss,
+ keys.LOSS_REGULARIZATION: expected_regularization_loss,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.3333,
+ keys.AUC_PR: 0.7639,
+ }
+ self._test_eval(
+ head=head,
+ logits=logits,
+ labels=labels,
+ expected_loss=expected_regularized_loss,
+ expected_metrics=expected_metrics,
+ regularization_losses=regularization_losses)
+
def test_eval_with_label_vocabulary(self):
n_classes = 2
head = head_lib.multi_label_head(
@@ -829,6 +862,49 @@ class MultiLabelHead(test.TestCase):
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
+ def test_train_with_regularization_losses(self):
+ head = head_lib.multi_label_head(
+ n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+ labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ regularization_losses = [1.5, 0.5]
+ # For large logits, sigmoid cross entropy loss is approximated as:
+ # loss = labels * (logits < 0) * (-logits) +
+ # (1 - labels) * (logits > 0) * logits =>
+ # expected_unweighted_loss = [[10., 10.], [15., 0.]]
+ # Average over classes and over batch and add regularization loss.
+ expected_loss = 35. / 4. + 2.
+ expected_summaries = {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ metric_keys.MetricKeys.LOSS_REGULARIZATION: 2.,
+ }
+ expected_train_result = 'my_train_op'
+ def _train_op_fn(loss):
+ return string_ops.string_join(
+ [constant_op.constant(expected_train_result),
+ string_ops.as_string(loss, precision=3)])
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ regularization_losses=regularization_losses)
+
+ # Assert predictions, loss, train_op, and summaries.
+ tol = 1e-3
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNotNone(spec.scaffold.summary_op)
+ loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
+ spec.scaffold.summary_op))
+ self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
+ self.assertEqual(
+ six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
+ train_result)
+ _assert_simple_summaries(self, expected_summaries, summary_str, tol)
+
def test_train_with_weights(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes, weight_column='example_weights')
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index 9a5413fc3f..4d0f9b2424 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -25,6 +25,7 @@ import time
from tensorflow.contrib.factorization.python.ops import clustering_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -32,6 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
@@ -207,6 +209,15 @@ class _ModelFn(object):
training_hooks.append(
_LossRelativeChangeHook(loss, self._relative_tolerance))
+ export_outputs = {
+ KMeansClustering.ALL_DISTANCES:
+ export_output.PredictOutput(all_distances[0]),
+ KMeansClustering.CLUSTER_INDEX:
+ export_output.PredictOutput(model_predictions[0]),
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ export_output.PredictOutput(model_predictions[0])
+ }
+
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions={
@@ -216,7 +227,8 @@ class _ModelFn(object):
loss=loss,
train_op=training_op,
eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)},
- training_hooks=training_hooks)
+ training_hooks=training_hooks,
+ export_outputs=export_outputs)
# TODO(agarwal,ands): support sharded input.
diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc
index 92fad70b1f..5ab57ca4cd 100644
--- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc
+++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc
@@ -44,7 +44,7 @@ const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"};
void Decode(OpKernelContext* context,
const tensorflow::StringPiece& file_contents,
const string& file_format, const int32 samples_per_second,
- const int32 channel_count) {
+ const int32 channel_count, const string& stream) {
// Write the input data to a temp file.
const string temp_filename = io::GetTempFilename(file_format);
OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents));
@@ -54,7 +54,7 @@ void Decode(OpKernelContext* context,
std::vector<float> output_samples;
Status result =
ffmpeg::ReadAudioFile(temp_filename, file_format, samples_per_second,
- channel_count, &output_samples);
+ channel_count, stream, &output_samples);
if (result.code() == error::Code::NOT_FOUND) {
OP_REQUIRES(
context, result.ok(),
@@ -99,7 +99,12 @@ void Decode(OpKernelContext* context,
*/
class DecodeAudioOpV2 : public OpKernel {
public:
- explicit DecodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {}
+ explicit DecodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {
+ string stream;
+ if (context->GetAttr("stream", &stream).ok()) {
+ stream_ = stream;
+ }
+ }
void Compute(OpKernelContext* context) override {
OP_REQUIRES(
@@ -153,8 +158,12 @@ class DecodeAudioOpV2 : public OpKernel {
errors::InvalidArgument("channel_count must be positive, but got: ",
channel_count));
- Decode(context, contents, file_format, samples_per_second, channel_count);
+ Decode(context, contents, file_format, samples_per_second, channel_count,
+ stream_);
}
+
+ private:
+ string stream_;
};
REGISTER_KERNEL_BUILDER(Name("DecodeAudioV2").Device(DEVICE_CPU),
@@ -166,6 +175,7 @@ REGISTER_OP("DecodeAudioV2")
.Input("samples_per_second: int32")
.Input("channel_count: int32")
.Output("sampled_audio: float")
+ .Attr("stream: string = ''")
.SetShapeFn([](shape_inference::InferenceContext* c) {
const Tensor* channels_tensor = c->input_tensor(3);
if (channels_tensor == nullptr) {
@@ -237,7 +247,7 @@ class DecodeAudioOp : public OpKernel {
const tensorflow::StringPiece file_contents = contents.scalar<string>()();
Decode(context, file_contents, file_format_, samples_per_second_,
- channel_count_);
+ channel_count_, "");
}
private:
diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
index 0d7c9cb99e..3dc663bb6f 100644
--- a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
+++ b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py
@@ -33,7 +33,8 @@ class DecodeAudioOpTest(test.TestCase):
def _loadFileAndTest(self, filename, file_format, duration_sec,
samples_per_second, channel_count,
- samples_per_second_tensor=None, feed_dict=None):
+ samples_per_second_tensor=None, feed_dict=None,
+ stream=None):
"""Loads an audio file and validates the output tensor.
Args:
@@ -49,6 +50,9 @@ class DecodeAudioOpTest(test.TestCase):
feed_dict: Used when evaluating the `decode_audio` op. If not
provided, will be empty. Useful when providing a placeholder for
`samples_per_second_tensor`.
+ stream: A string specifying which stream from the content file
+ should be decoded. The default value is '' which leaves the
+ decision to ffmpeg.
"""
if samples_per_second_tensor is None:
samples_per_second_tensor = samples_per_second
@@ -62,7 +66,7 @@ class DecodeAudioOpTest(test.TestCase):
contents,
file_format=file_format,
samples_per_second=samples_per_second_tensor,
- channel_count=channel_count)
+ channel_count=channel_count, stream=stream)
audio = audio_op.eval(feed_dict=feed_dict or {})
self.assertEqual(len(audio.shape), 2)
self.assertNear(
@@ -72,6 +76,17 @@ class DecodeAudioOpTest(test.TestCase):
0.1 * audio.shape[0])
self.assertEqual(audio.shape[1], channel_count)
+ def testStreamIdentifier(self):
+ # mono_16khz_mp3_32khz_aac.mp4 was generated from:
+ # ffmpeg -i tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3.mp4 \
+ # -i tensorflow/contrib/ffmpeg/testdata/mono_32khz_aac.mp4 \
+ # -strict -2 -map 0:a -map 1:a \
+ # tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4
+ self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77, 20000,
+ 1, stream='0')
+ self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77, 20000,
+ 1, stream='1')
+
def testMonoMp3(self):
self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1)
self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 2)
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index 1e8af1458c..c85b1837ab 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -44,8 +44,10 @@ std::vector<string> FfmpegAudioCommandLine(const string& input_filename,
const string& output_filename,
const string& input_format_id,
int32 samples_per_second,
- int32 channel_count) {
- return {"-nostats", // No additional progress display.
+ int32 channel_count,
+ const string& stream) {
+ std::vector<string> command({
+ "-nostats", // No additional progress display.
"-nostdin", // No interactive commands accepted.
"-f", input_format_id, // eg: "mp3"
"-probesize", StrCat(kDefaultProbeSize), "-i", input_filename,
@@ -58,8 +60,15 @@ std::vector<string> FfmpegAudioCommandLine(const string& input_filename,
// Output set (in several ways) to signed 16-bit little-endian ints.
"-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le",
"-sn", // No subtitle recording.
- "-y", // Overwrite output file.
- StrCat(output_filename)};
+ "-y" // Overwrite output file.
+ });
+ if (!stream.empty()) {
+ command.emplace_back("-map");
+ command.emplace_back(StrCat("0:", stream));
+ }
+ command.emplace_back(StrCat(output_filename));
+
+ return command;
}
std::vector<string> FfmpegVideoCommandLine(const string& input_filename,
@@ -73,7 +82,9 @@ std::vector<string> FfmpegVideoCommandLine(const string& input_filename,
"-probesize",
StrCat(kDefaultProbeSize),
"-loglevel",
- "error", // Print errors only.
+ // Info is needed to get the information about stream, etc.
+ // It is generated to a separate file, not stdout/stderr.
+ "info",
"-hide_banner", // Skip printing build options, version, etc.
"-vcodec",
"rawvideo",
@@ -123,7 +134,6 @@ bool IsBinaryInstalled(const string& binary_name) {
std::transform(args.begin(), args.end(), std::back_inserter(args_chars),
[](const string& s) { return const_cast<char*>(s.c_str()); });
args_chars.push_back(nullptr);
-
::execvp(kFfmpegExecutable, args_chars.data());
// exec only returns on error.
const int error = errno;
@@ -308,13 +318,12 @@ Status WriteFile(const string& filename, StringPiece contents) {
Status ReadAudioFile(const string& filename, const string& audio_format_id,
int32 samples_per_second, int32 channel_count,
- std::vector<float>* output_samples) {
+ const string& stream, std::vector<float>* output_samples) {
// Create an argument list.
string output_filename = io::GetTempFilename("raw");
const std::vector<string> args =
FfmpegAudioCommandLine(filename, output_filename, audio_format_id,
- samples_per_second, channel_count);
-
+ samples_per_second, channel_count, stream);
// Unfortunately, it's impossible to differentiate an exec failure due to the
// binary being missing and an error from the binary's execution. Therefore,
// check to see if the binary *should* be available. If not, return an error
@@ -368,7 +377,6 @@ Status ReadVideoFile(const string& filename, std::vector<uint8>* output_data,
// Create an argument list.
const std::vector<string> args =
FfmpegVideoCommandLine(filename, output_filename);
-
// Execute ffmpeg and report errors.
pid_t child_pid = ::fork();
if (child_pid < 0) {
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h
index c5ea1432bf..a8d5a0dd83 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_
+#ifndef TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_
+#define TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_
#include <string>
#include <vector>
@@ -42,7 +42,7 @@ Status WriteFile(const string& filename, tensorflow::StringPiece contents);
// contain a separate sample for each channel. Frames are ordered by time.
Status ReadAudioFile(const string& filename, const string& audio_format_id,
int32 samples_per_second, int32 channel_count,
- std::vector<float>* output_samples);
+ const string& stream, std::vector<float>* output_samples);
// Creates an audio file using ffmpeg in a specific format. The samples are in
// [-1.0, 1.0]. If there are multiple channels in the audio then each frame will
@@ -61,4 +61,4 @@ Status ReadVideoFile(const string& filename, std::vector<uint8>* output_data,
} // namespace ffmpeg
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_DEFAULT_FFMPEG_LIB_H_
+#endif // TENSORFLOW_CONTRIB_FFMPEG_DEFAULT_FFMPEG_LIB_H_
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index 08b5a6ea48..020b5c99c6 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -31,7 +31,7 @@ _ffmpeg_so = loader.load_op_library(
def decode_audio(contents, file_format=None, samples_per_second=None,
- channel_count=None):
+ channel_count=None, stream=None):
"""Create an op that decodes the contents of an audio file.
Note that ffmpeg is free to select the "best" audio track from an mp4.
@@ -51,6 +51,9 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
`contents` have more than this number, then some channels will
be merged or dropped. If `contents` has fewer than this, then
additional channels will be created from the existing ones.
+ stream: A string specifying which stream from the content file
+ should be decoded, e.g., '0' means the 0-th stream.
+ The default value is '' which leaves the decision to ffmpeg.
Returns:
A rank-2 tensor that has time along dimension 0 and channels along
@@ -61,7 +64,7 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
"""
return gen_decode_audio_op_py.decode_audio_v2(
contents, file_format=file_format, samples_per_second=samples_per_second,
- channel_count=channel_count)
+ channel_count=channel_count, stream=stream)
ops.NotDifferentiable('DecodeAudio')
diff --git a/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 b/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4
new file mode 100644
index 0000000000..2485da86d6
--- /dev/null
+++ b/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4
Binary files differ
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index 2effe8eb26..8cdb340f2d 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@@ -77,6 +78,7 @@ class AssertScalarIntTest(test.TestCase):
[3, 4], dtype=dtypes.int32))
+@test_util.with_c_api
class WithShapeTest(test.TestCase):
def _assert_with_shape(self, tensor, expected_value, expected_shape,
@@ -213,16 +215,25 @@ class WithShapeTest(test.TestCase):
tensor_partial_shape.set_shape([None, 2])
for incompatible_shape in [[0], [1]]:
+ if ops._USE_C_API:
+ error_message = "Shapes must be equal rank, but are 2 and 1"
+ else:
+ error_message = r"Shapes \(\?, 2\) and \([01],\) are not compatible"
self.assertRaisesRegexp(
- ValueError, r"Shapes \(\?, 2\) and \([01],\) are not compatible",
+ ValueError, error_message,
tensor_util.with_shape, incompatible_shape, tensor_partial_shape)
for incompatible_shape in [[1, 2, 1]]:
self.assertRaisesRegexp(ValueError, "Dimensions must be equal",
tensor_util.with_shape, incompatible_shape,
tensor_partial_shape)
for incompatible_shape in [[2, 1]]:
+ if ops._USE_C_API:
+ error_message = (r"Dimension 1 in both shapes must be equal, but are "
+ r"2 and 1. Shapes are \[\?,2\] and \[2,1\].")
+ else:
+ error_message = r"Shapes \(\?, 2\) and \(2, 1\) are not compatible"
self.assertRaisesRegexp(
- ValueError, r"Shapes \(\?, 2\) and \(2, 1\) are not compatible",
+ ValueError, error_message,
tensor_util.with_shape, incompatible_shape, tensor_partial_shape)
compatible_shape = [2, 2]
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
index fa7a3c03aa..ba52697679 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
+#ifndef TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
+#define TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
#if GOOGLE_CUDA
@@ -72,4 +72,4 @@ class FusedConvParameters : public ConvParameters {
#endif // GOOGLE_CUDA
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
+#endif // TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
index 6a56237f67..bafd1d5941 100644
--- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
@@ -25,13 +25,6 @@ limitations under the License.
namespace tensorflow {
-namespace {
-// Return the string containing the list of valid activation modes, that can be
-// used as an Attr() in REGISTER_OP.
-string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; }
-
-} // namespace
-
// --------------------------------------------------------------------------
// TODO(pauldonnelly): Add support for double inputs and scales to this Op,
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index b355a79b1a..5db34f0f8d 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -177,6 +177,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":losses_impl",
+ ":namedtuples",
"//tensorflow/python:util",
],
)
@@ -188,6 +189,9 @@ py_test(
deps = [
":tuple_losses",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
@@ -395,6 +399,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":eval_utils",
+ ":namedtuples",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
index 508b4d20d8..74811ff409 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -48,6 +49,15 @@ def add_gan_model_image_summaries(gan_model, grid_size=4):
Raises:
ValueError: If real and generated data aren't images.
"""
+ if isinstance(gan_model, namedtuples.CycleGANModel):
+ saved_params = locals()
+ saved_params.pop('gan_model', None)
+ with ops.name_scope('cyclegan_x2y_image_summaries'):
+ add_gan_model_image_summaries(gan_model.model_x2y, **saved_params)
+ with ops.name_scope('cyclegan_y2x_image_summaries'):
+ add_gan_model_image_summaries(gan_model.model_y2x, **saved_params)
+ return
+
_assert_is_image(gan_model.real_data)
_assert_is_image(gan_model.generated_data)
@@ -96,6 +106,15 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2,
ValueError: If the generator input, real, and generated data aren't all the
same size.
"""
+ if isinstance(gan_model, namedtuples.CycleGANModel):
+ saved_params = locals()
+ saved_params.pop('gan_model', None)
+ with ops.name_scope('cyclegan_x2y_image_comparison_summaries'):
+ add_image_comparison_summaries(gan_model.model_x2y, **saved_params)
+ with ops.name_scope('cyclegan_y2x_image_comparison_summaries'):
+ add_image_comparison_summaries(gan_model.model_y2x, **saved_params)
+ return
+
_assert_is_image(gan_model.generator_inputs)
_assert_is_image(gan_model.generated_data)
_assert_is_image(gan_model.real_data)
@@ -133,6 +152,13 @@ def add_gan_model_summaries(gan_model):
Args:
gan_model: A GANModel tuple.
"""
+ if isinstance(gan_model, namedtuples.CycleGANModel):
+ with ops.name_scope('cyclegan_x2y_summaries'):
+ add_gan_model_summaries(gan_model.model_x2y)
+ with ops.name_scope('cyclegan_y2x_summaries'):
+ add_gan_model_summaries(gan_model.model_y2x)
+ return
+
with ops.name_scope('generator_variables'):
for var in gan_model.generator_variables:
summary.histogram(var.name, var)
@@ -147,6 +173,13 @@ def add_regularization_loss_summaries(gan_model):
Args:
gan_model: A GANModel tuple.
"""
+ if isinstance(gan_model, namedtuples.CycleGANModel):
+ with ops.name_scope('cyclegan_x2y_regularization_loss_summaries'):
+ add_regularization_loss_summaries(gan_model.model_x2y)
+ with ops.name_scope('cyclegan_y2x_regularization_loss_summaries'):
+ add_regularization_loss_summaries(gan_model.model_y2x)
+ return
+
if gan_model.generator_scope:
summary.scalar(
'generator_regularization_loss',
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
index a3b02bcefc..a02d8772e1 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
@@ -57,40 +57,83 @@ def get_gan_model():
discriminator_fn=discriminator_model)
+def get_cyclegan_model():
+ with variable_scope.variable_scope('x2y'):
+ model_x2y = get_gan_model()
+ with variable_scope.variable_scope('y2x'):
+ model_y2x = get_gan_model()
+ return namedtuples.CycleGANModel(
+ model_x2y=model_x2y,
+ model_y2x=model_y2x,
+ reconstructed_x=array_ops.zeros([3, 30, 35, 6]),
+ reconstructed_y=array_ops.zeros([3, 30, 35, 6]))
+
+
class SummariesTest(test.TestCase):
- def testAddGanModelImageSummaries(self):
- summaries.add_gan_model_image_summaries(get_gan_model(), grid_size=2)
+ def _test_add_gan_model_image_summaries_impl(self, get_model_fn,
+ expected_num_summary_ops):
+ summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2)
- self.assertEquals(5, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ self.assertEquals(expected_num_summary_ops,
+ len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
with self.test_session(use_gpu=True):
variables.global_variables_initializer().run()
summary.merge_all().eval()
- def testAddGanModelSummaries(self):
- summaries.add_gan_model_summaries(get_gan_model())
+ def test_add_gan_model_image_summaries(self):
+ self._test_add_gan_model_image_summaries_impl(get_gan_model, 5)
+
+ def test_add_gan_model_image_summaries_for_cyclegan(self):
+ self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10)
- self.assertEquals(3, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ def _test_add_gan_model_summaries_impl(self, get_model_fn,
+ expected_num_summary_ops):
+ summaries.add_gan_model_summaries(get_model_fn())
+
+ self.assertEquals(expected_num_summary_ops,
+ len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
with self.test_session(use_gpu=True):
variables.global_variables_initializer().run()
summary.merge_all().eval()
- def testAddRegularizationLossSummaries(self):
- summaries.add_regularization_loss_summaries(get_gan_model())
+ def test_add_gan_model_summaries(self):
+ self._test_add_gan_model_summaries_impl(get_gan_model, 3)
+
+ def test_add_gan_model_summaries_for_cyclegan(self):
+ self._test_add_gan_model_summaries_impl(get_cyclegan_model, 6)
- self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ def _test_add_regularization_loss_summaries_impl(self, get_model_fn,
+ expected_num_summary_ops):
+ summaries.add_regularization_loss_summaries(get_model_fn())
+
+ self.assertEquals(expected_num_summary_ops,
+ len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
with self.test_session(use_gpu=True):
summary.merge_all().eval()
+ def test_add_regularization_loss_summaries(self):
+ self._test_add_regularization_loss_summaries_impl(get_gan_model, 2)
+
+ def test_add_regularization_loss_summaries_for_cyclegan(self):
+ self._test_add_regularization_loss_summaries_impl(get_cyclegan_model, 4)
+
# TODO(joelshor): Add correctness test.
- def testAddImageComparisonSummaries(self):
- summaries.add_image_comparison_summaries(
- get_gan_model(), display_diffs=True)
+ def _test_add_image_comparison_summaries_impl(self, get_model_fn,
+ expected_num_summary_ops):
+ summaries.add_image_comparison_summaries(get_model_fn(), display_diffs=True)
- self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ self.assertEquals(expected_num_summary_ops,
+ len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
with self.test_session(use_gpu=True):
summary.merge_all().eval()
+ def test_add_image_comparison_summaries(self):
+ self._test_add_image_comparison_summaries_impl(get_gan_model, 1)
+
+ def test_add_image_comparison_summaries_for_cyclegan(self):
+ self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index 940762cf2a..23a3b60cc0 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -67,6 +67,7 @@ __all__ = [
'wasserstein_gradient_penalty',
'mutual_information_penalty',
'combine_adversarial_loss',
+ 'cycle_consistency_loss',
]
@@ -915,3 +916,63 @@ def combine_adversarial_loss(main_loss,
array_ops.stop_gradient(adv_coeff) * adversarial_loss)
return final_loss
+
+
+def cycle_consistency_loss(data_x,
+ reconstructed_data_x,
+ data_y,
+ reconstructed_data_y,
+ scope=None,
+ add_summaries=False):
+ """Defines the cycle consistency loss.
+
+ The cyclegan model has two partial models where `model_x2y` generator F maps
+ data set X to Y, `model_y2x` generator G maps data set Y to X. For a `data_x`
+ in data set X, we could reconstruct it by
+ * reconstructed_data_x = G(F(data_x))
+ Similarly
+ * reconstructed_data_y = F(G(data_y))
+
+ The cycle consistency loss is about the difference between data and
+ reconstructed data, namely
+ * loss_x2x = |data_x - G(F(data_x))| (L1-norm)
+ * loss_y2y = |data_y - F(G(data_y))| (L1-norm)
+ * loss = (loss_x2x + loss_y2y) / 2
+ where `loss` is the final result.
+
+ See https://arxiv.org/abs/1703.10593 for more details.
+
+ Args:
+ data_x: A `Tensor` of data X.
+ reconstructed_data_x: A `Tensor` of reconstructed data X.
+ data_y: A `Tensor` of data Y.
+ reconstructed_data_y: A `Tensor` of reconstructed data Y.
+ scope: The scope for the operations performed in computing the loss.
+ Defaults to None.
+ add_summaries: Whether or not to add detailed summaries for the loss.
+ Defaults to False.
+
+ Returns:
+ A scalar `Tensor` of cycle consistency loss.
+ """
+
+ def _partial_cycle_consistency_loss(data, reconstructed_data):
+ # Following the original implementation
+ # https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua
+ # use L1-norm of pixel-wise error normalized by data size so that
+ # `cycle_loss_weight` can be specified independent of image size.
+ return math_ops.reduce_mean(math_ops.abs(data - reconstructed_data))
+
+ with ops.name_scope(
+ scope,
+ 'cycle_consistency_loss',
+ values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]):
+ loss_x2x = _partial_cycle_consistency_loss(data_x, reconstructed_data_x)
+ loss_y2y = _partial_cycle_consistency_loss(data_y, reconstructed_data_y)
+ loss = (loss_x2x + loss_y2y) / 2.0
+ if add_summaries:
+ summary.scalar('cycle_consistency_loss_x2x', loss_x2x)
+ summary.scalar('cycle_consistency_loss_y2y', loss_y2y)
+ summary.scalar('cycle_consistency_loss', loss)
+
+ return loss
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index b5cd8c92ba..56ac45554d 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -620,7 +620,34 @@ class CombineAdversarialLossTest(test.TestCase):
with self.test_session(use_gpu=True) as sess:
for _ in range(10): # spot check closeness on more than one sample.
gnorm_np, precond_gnorm_np = sess.run([gnorm, precond_gnorm])
- self.assertNear(gnorm_np, precond_gnorm_np, 1e-5)
+ self.assertNear(gnorm_np, precond_gnorm_np, 1e-4)
+
+
+class CycleConsistencyLossTest(test.TestCase):
+ """Tests for cycle_consistency_loss."""
+
+ def setUp(self):
+ super(CycleConsistencyLossTest, self).setUp()
+
+ self._data_x_np = [[1.0, 2, 3], [4, 5, 6]]
+ self._reconstructed_data_x_np = [[7.0, 8, 9], [10, 11, 12]]
+ self._data_y_np = [1.0, 9]
+ self._reconstructed_data_y_np = [-2.0, 3]
+
+ self._data_x = constant_op.constant(self._data_x_np, dtype=dtypes.float32)
+ self._reconstructed_data_x = constant_op.constant(
+ self._reconstructed_data_x_np, dtype=dtypes.float32)
+ self._data_y = constant_op.constant(self._data_y_np, dtype=dtypes.float32)
+ self._reconstructed_data_y = constant_op.constant(
+ self._reconstructed_data_y_np, dtype=dtypes.float32)
+
+ def test_correct_loss(self):
+ loss = tfgan_losses.cycle_consistency_loss(
+ self._data_x, self._reconstructed_data_x, self._data_y,
+ self._reconstructed_data_y)
+ with self.test_session(use_gpu=True):
+ variables.global_variables_initializer().run()
+ self.assertNear(5.25, loss.eval(), 1e-5)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
index b341f03a0d..dcc3f94c2d 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
@@ -60,6 +60,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.losses.python import losses_impl
from tensorflow.python.util import tf_inspect
@@ -78,6 +79,7 @@ __all__ = [
'wasserstein_gradient_penalty',
'mutual_information_penalty',
'combine_adversarial_loss',
+ 'cycle_consistency_loss',
]
@@ -246,3 +248,32 @@ def combine_adversarial_loss(gan_loss,
scalar_summaries,
gradient_summaries)
return gan_loss._replace(generator_loss=combined_loss)
+
+
+def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False):
+ """Defines the cycle consistency loss.
+
+ Uses `cycle_consistency_loss` to compute the cycle consistency loss for a
+ `cyclegan_model`.
+
+ Args:
+ cyclegan_model: A `CycleGANModel` namedtuple.
+ scope: The scope for the operations performed in computing the loss.
+ Defaults to None.
+ add_summaries: Whether or not to add detailed summaries for the loss.
+ Defaults to False.
+
+ Returns:
+ A scalar `Tensor` of cycle consistency loss.
+
+ Raises:
+ ValueError: If `cyclegan_model` is not a `CycleGANModel` namedtuple.
+ """
+ if not isinstance(cyclegan_model, namedtuples.CycleGANModel):
+ raise ValueError(
+ '`cyclegan_model` must be a `CycleGANModel`. Instead, was %s.' %
+ type(cyclegan_model))
+ return losses_impl.cycle_consistency_loss(
+ cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x,
+ cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y,
+ scope, add_summaries)
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index 215b15ef69..aa1ef11172 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -22,8 +22,11 @@ import collections
import numpy as np
+from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses
-
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -125,6 +128,7 @@ manual_tests = [
'combine_adversarial_loss',
'mutual_information_penalty',
'wasserstein_gradient_penalty',
+ 'cycle_consistency_loss',
]
discriminator_keyword_args = {
@@ -139,6 +143,38 @@ generator_keyword_args = {
}
+class CycleConsistencyLossTest(test.TestCase):
+
+ def setUp(self):
+ super(CycleConsistencyLossTest, self).setUp()
+
+ def _partial_model(generator_inputs_np):
+ model = namedtuples.GANModel(*[None] * 11)
+ return model._replace(
+ generator_inputs=constant_op.constant(
+ generator_inputs_np, dtype=dtypes.float32))
+
+ self._model_x2y = _partial_model([1, 2])
+ self._model_y2x = _partial_model([5, 6])
+
+ def test_model_type(self):
+ """Test the input model type for `cycle_consistency_loss`."""
+ with self.assertRaises(ValueError):
+ tfgan_losses.cycle_consistency_loss(self._model_x2y)
+
+ def test_correct_loss(self):
+ """Test the output of `cycle_consistency_loss`."""
+ loss = tfgan_losses.cycle_consistency_loss(
+ namedtuples.CycleGANModel(
+ model_x2y=self._model_x2y,
+ model_y2x=self._model_y2x,
+ reconstructed_x=constant_op.constant([9, 8], dtype=dtypes.float32),
+ reconstructed_y=constant_op.constant([7, 2], dtype=dtypes.float32)))
+ with self.test_session(use_gpu=True):
+ variables.global_variables_initializer().run()
+ self.assertNear(5.0, loss.eval(), 1e-5)
+
+
if __name__ == '__main__':
for loss_name in tfgan_losses.__all__:
if loss_name in manual_tests: continue
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index 3d4e315ebd..25cfeafeec 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -30,7 +30,9 @@ __all__ = [
'GANModel',
'InfoGANModel',
'ACGANModel',
+ 'CycleGANModel',
'GANLoss',
+ 'CycleGANLoss',
'GANTrainOps',
'GANTrainSteps',
]
@@ -115,6 +117,25 @@ class ACGANModel(
"""
+class CycleGANModel(
+ collections.namedtuple(
+ 'CycleGANModel',
+ ('model_x2y', 'model_y2x', 'reconstructed_x', 'reconstructed_y'))):
+ """An CycleGANModel contains all the pieces needed for CycleGAN training.
+
+ The model `model_x2y` generator F maps data set X to Y, while the model
+ `model_y2x` generator G maps data set Y to X.
+
+ See https://arxiv.org/abs/1703.10593 for more details.
+
+ Args:
+ model_x2y: A `GANModel` namedtuple whose generator maps data set X to Y.
+ model_y2x: A `GANModel` namedtuple whose generator maps data set Y to X.
+ reconstructed_x: A `Tensor` of reconstructed data X which is G(F(X)).
+ reconstructed_y: A `Tensor` of reconstructed data Y which is F(G(Y)).
+ """
+
+
class GANLoss(
collections.namedtuple('GANLoss', (
'generator_loss',
@@ -128,6 +149,18 @@ class GANLoss(
"""
+class CycleGANLoss(
+ collections.namedtuple('CycleGANLoss', ('loss_x2y', 'loss_y2x'))):
+ """CycleGANLoss contains the losses for `CycleGANModel`.
+
+ See https://arxiv.org/abs/1703.10593 for more details.
+
+ Args:
+ loss_x2y: A `GANLoss` namedtuple representing the loss of `model_x2y`.
+ loss_y2x: A `GANLoss` namedtuple representing the loss of `model_y2x`.
+ """
+
+
class GANTrainOps(
collections.namedtuple('GANTrainOps', (
'generator_train_op',
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index c429ec4831..5d0ac93aec 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -52,7 +52,9 @@ __all__ = [
'gan_model',
'infogan_model',
'acgan_model',
+ 'cyclegan_model',
'gan_loss',
+ 'cyclegan_loss',
'gan_train_ops',
'gan_train',
'get_sequential_train_hooks',
@@ -277,14 +279,16 @@ def acgan_model(
generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
- (discriminator_gen_outputs, discriminator_gen_classification_logits
- ) = _validate_acgan_discriminator_outputs(
- discriminator_fn(generated_data, generator_inputs))
+ with ops.name_scope(dis_scope.name+'/generated/'):
+ (discriminator_gen_outputs, discriminator_gen_classification_logits
+ ) = _validate_acgan_discriminator_outputs(
+ discriminator_fn(generated_data, generator_inputs))
with variable_scope.variable_scope(dis_scope, reuse=True):
- real_data = ops.convert_to_tensor(real_data)
- (discriminator_real_outputs, discriminator_real_classification_logits
- ) = _validate_acgan_discriminator_outputs(
- discriminator_fn(real_data, generator_inputs))
+ with ops.name_scope(dis_scope.name+'/real/'):
+ real_data = ops.convert_to_tensor(real_data)
+ (discriminator_real_outputs, discriminator_real_classification_logits
+ ) = _validate_acgan_discriminator_outputs(
+ discriminator_fn(real_data, generator_inputs))
if check_shapes:
if not generated_data.shape.is_compatible_with(real_data.shape):
raise ValueError(
@@ -305,6 +309,76 @@ def acgan_model(
discriminator_gen_classification_logits)
+def cyclegan_model(
+ # Lambdas defining models.
+ generator_fn,
+ discriminator_fn,
+ # data X and Y.
+ data_x,
+ data_y,
+ # Optional scopes.
+ generator_scope='Generator',
+ discriminator_scope='Discriminator',
+ model_x2y_scope='ModelX2Y',
+ model_y2x_scope='ModelY2X',
+ # Options.
+ check_shapes=True):
+ """Returns a CycleGAN model outputs and variables.
+
+ See https://arxiv.org/abs/1703.10593 for more details.
+
+ Args:
+ generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and
+ returns the outputs of the GAN generator.
+ discriminator_fn: A python lambda that takes `real_data`/`generated data`
+ and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
+ data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
+ data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
+ generator_scope: Optional generator variable scope. Useful if you want to
+ reuse a subgraph that has already been created. Defaults to 'Generator'.
+ discriminator_scope: Optional discriminator variable scope. Useful if you
+ want to reuse a subgraph that has already been created. Defaults to
+ 'Discriminator'.
+ model_x2y_scope: Optional variable scope for model x2y variables. Defaults
+ to 'ModelX2Y'.
+ model_y2x_scope: Optional variable scope for model y2x variables. Defaults
+ to 'ModelY2X'.
+ check_shapes: If `True`, check that generator produces Tensors that are the
+ same shape as `data_x` (`data_y`). Otherwise, skip this check.
+
+ Returns:
+ A `CycleGANModel` namedtuple.
+
+ Raises:
+ ValueError: If `check_shapes` is True and `data_x` or the generator output
+ does not have the same shape as `data_y`.
+ """
+
+ # Create models.
+ def _define_partial_model(input_data, output_data):
+ return gan_model(
+ generator_fn=generator_fn,
+ discriminator_fn=discriminator_fn,
+ real_data=output_data,
+ generator_inputs=input_data,
+ generator_scope=generator_scope,
+ discriminator_scope=discriminator_scope,
+ check_shapes=check_shapes)
+
+ with variable_scope.variable_scope(model_x2y_scope):
+ model_x2y = _define_partial_model(data_x, data_y)
+ with variable_scope.variable_scope(model_y2x_scope):
+ model_y2x = _define_partial_model(data_y, data_x)
+
+ with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True):
+ reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
+ with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True):
+ reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)
+
+ return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
+ reconstructed_y)
+
+
def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
if isinstance(aux_loss_weight, ops.Tensor):
aux_loss_weight.shape.assert_is_compatible_with([])
@@ -494,6 +568,69 @@ def gan_loss(
return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)
+def cyclegan_loss(
+ model,
+ # Loss functions.
+ generator_loss_fn=tfgan_losses.least_squares_generator_loss,
+ discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss,
+ # Auxiliary losses.
+ cycle_consistency_loss_fn=tfgan_losses.cycle_consistency_loss,
+ cycle_consistency_loss_weight=10.0,
+ # Options
+ **kwargs):
+ """Returns the losses for a `CycleGANModel`.
+
+ See https://arxiv.org/abs/1703.10593 for more details.
+
+ Args:
+ model: A `CycleGANModel` namedtuple.
+ generator_loss_fn: The loss function on the generator. Takes a `GANModel`
+ named tuple.
+ discriminator_loss_fn: The loss function on the discriminator. Takes a
+ `GANModel` namedtuple.
+ cycle_consistency_loss_fn: The cycle consistency loss function. Takes a
+ `CycleGANModel` namedtuple.
+ cycle_consistency_loss_weight: A non-negative Python number or a scalar
+ `Tensor` indicating how much to weigh the cycle consistency loss.
+ **kwargs: Keyword args to pass directly to `gan_loss` to construct the loss
+ for each partial model of `model`.
+
+ Returns:
+ A `CycleGANLoss` namedtuple.
+
+ Raises:
+ ValueError: If `model` is not a `CycleGANModel` namedtuple.
+ """
+ # Sanity checks.
+ if not isinstance(model, namedtuples.CycleGANModel):
+ raise ValueError(
+ '`model` must be a `CycleGANModel`. Instead, was %s.' % type(model))
+
+ # Defines cycle consistency loss.
+ cycle_consistency_loss = cycle_consistency_loss_fn(
+ model, add_summaries=kwargs.get('add_summaries', True))
+ cycle_consistency_loss_weight = _validate_aux_loss_weight(
+ cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
+ aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss
+
+ # Defines losses for each partial model.
+ def _partial_loss(partial_model):
+ partial_loss = gan_loss(
+ partial_model,
+ generator_loss_fn=generator_loss_fn,
+ discriminator_loss_fn=discriminator_loss_fn,
+ **kwargs)
+ return partial_loss._replace(
+ generator_loss=partial_loss.generator_loss + aux_loss)
+
+ with ops.name_scope('cyclegan_loss_x2y'):
+ loss_x2y = _partial_loss(model.model_x2y)
+ with ops.name_scope('cyclegan_loss_y2x'):
+ loss_y2x = _partial_loss(model.model_y2x)
+
+ return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)
+
+
def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
"""Gets generator and discriminator update ops.
@@ -561,6 +698,24 @@ def gan_train_ops(
A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can
be used to train a generator/discriminator pair.
"""
+ if isinstance(model, namedtuples.CycleGANModel):
+ saved_params = locals()
+ saved_params.pop('model', None)
+ saved_params.pop('loss', None)
+ kwargs = saved_params.pop('kwargs', {})
+ saved_params.update(kwargs)
+ with ops.name_scope('cyclegan_x2y_train'):
+ train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
+ **saved_params)
+ with ops.name_scope('cyclegan_y2x_train'):
+ train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
+ **saved_params)
+ return namedtuples.GANTrainOps(
+ (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
+ (train_ops_x2y.discriminator_train_op,
+ train_ops_y2x.discriminator_train_op),
+ training_util.get_or_create_global_step().assign_add(1))
+
# Create global step increment op.
global_step = training_util.get_or_create_global_step()
global_step_inc = global_step.assign_add(1)
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 58704e6859..f9bdaa74c9 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -210,6 +210,38 @@ def create_callable_acgan_model():
one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
+def get_cyclegan_model():
+ return namedtuples.CycleGANModel(
+ model_x2y=get_gan_model(),
+ model_y2x=get_gan_model(),
+ reconstructed_x=array_ops.ones([1, 2, 3]),
+ reconstructed_y=array_ops.zeros([1, 2, 3]))
+
+
+def get_callable_cyclegan_model():
+ return namedtuples.CycleGANModel(
+ model_x2y=get_callable_gan_model(),
+ model_y2x=get_callable_gan_model(),
+ reconstructed_x=array_ops.ones([1, 2, 3]),
+ reconstructed_y=array_ops.zeros([1, 2, 3]))
+
+
+def create_cyclegan_model():
+ return train.cyclegan_model(
+ generator_model,
+ discriminator_model,
+ data_x=array_ops.zeros([1, 2]),
+ data_y=array_ops.ones([1, 2]))
+
+
+def create_callable_cyclegan_model():
+ return train.cyclegan_model(
+ Generator(),
+ Discriminator(),
+ data_x=array_ops.zeros([1, 2]),
+ data_y=array_ops.ones([1, 2]))
+
+
def get_sync_optimizer():
return sync_replicas_optimizer.SyncReplicasOptimizer(
gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
@@ -261,6 +293,13 @@ class GANModelTest(test.TestCase):
self._test_output_type_helper(
get_callable_acgan_model, namedtuples.ACGANModel)
+ def test_output_type_cyclegan(self):
+ self._test_output_type_helper(get_cyclegan_model, namedtuples.CycleGANModel)
+
+ def test_output_type_callable_cyclegan(self):
+ self._test_output_type_helper(get_callable_cyclegan_model,
+ namedtuples.CycleGANModel)
+
def test_no_shape_check(self):
def dummy_generator_model(_):
return (None, None)
@@ -308,6 +347,17 @@ class GANLossTest(test.TestCase):
def test_output_type_callable_acgan(self):
self._test_output_type_helper(get_callable_acgan_model)
+ def test_output_type_cyclegan(self):
+ loss = train.cyclegan_loss(create_cyclegan_model(), add_summaries=True)
+ self.assertIsInstance(loss, namedtuples.CycleGANLoss)
+ self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
+
+ def test_output_type_callable_cyclegan(self):
+ loss = train.cyclegan_loss(
+ create_callable_cyclegan_model(), add_summaries=True)
+ self.assertIsInstance(loss, namedtuples.CycleGANLoss)
+ self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
+
# Test gradient penalty option.
def _test_grad_penalty_helper(self, create_gan_model_fn):
model = create_gan_model_fn()
@@ -431,6 +481,34 @@ class GANLossTest(test.TestCase):
def test_callable_acgan(self):
self._test_acgan_helper(create_callable_acgan_model)
+ # Test that CycleGan models work.
+ def _test_cyclegan_helper(self, create_gan_model_fn):
+ model = create_gan_model_fn()
+ loss = train.cyclegan_loss(model)
+ self.assertIsInstance(loss, namedtuples.CycleGANLoss)
+
+ # Check values.
+ with self.test_session(use_gpu=True) as sess:
+ variables.global_variables_initializer().run()
+ (loss_x2y_gen_np, loss_x2y_dis_np, loss_y2x_gen_np,
+ loss_y2x_dis_np) = sess.run([
+ loss.loss_x2y.generator_loss, loss.loss_x2y.discriminator_loss,
+ loss.loss_y2x.generator_loss, loss.loss_y2x.discriminator_loss
+ ])
+
+ self.assertGreater(loss_x2y_gen_np, loss_x2y_dis_np)
+ self.assertGreater(loss_y2x_gen_np, loss_y2x_dis_np)
+ self.assertTrue(np.isscalar(loss_x2y_gen_np))
+ self.assertTrue(np.isscalar(loss_x2y_dis_np))
+ self.assertTrue(np.isscalar(loss_y2x_gen_np))
+ self.assertTrue(np.isscalar(loss_y2x_dis_np))
+
+ def test_cyclegan(self):
+ self._test_cyclegan_helper(create_cyclegan_model)
+
+ def test_callable_cyclegan(self):
+ self._test_cyclegan_helper(create_callable_cyclegan_model)
+
def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2,
pool_size):
history_values = []
diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD
index bdbe6f0a72..707ae25d48 100644
--- a/tensorflow/contrib/gdr/BUILD
+++ b/tensorflow/contrib/gdr/BUILD
@@ -82,6 +82,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:graph_mgr",
+ "//tensorflow/core/distributed_runtime:recent_request_ids",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker",
"//tensorflow/core/distributed_runtime:worker_cache",
@@ -103,6 +104,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:request_id",
"//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
index adef2aac33..28f68cec8c 100644
--- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
+++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
@@ -47,6 +48,7 @@ class GdrRecvTensorCall : public BaseRecvTensorCall {
recv_args_(recv_args) {
req_.set_step_id(step_id);
req_.set_rendezvous_key(key.data(), key.size());
+ req_.set_request_id(GetUniqueRequestId());
}
~GdrRecvTensorCall() override {}
diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc
index 5686412347..ce1d8d2d73 100644
--- a/tensorflow/contrib/gdr/gdr_worker.cc
+++ b/tensorflow/contrib/gdr/gdr_worker.cc
@@ -41,17 +41,26 @@ namespace tensorflow {
GdrWorker::GdrWorker(WorkerEnv* worker_env,
RemoteMemoryManager* remote_memory_manager)
- : GrpcWorker(worker_env), remote_memory_manager_(remote_memory_manager) {}
+ : GrpcWorker(worker_env),
+ remote_memory_manager_(remote_memory_manager),
+ recv_tensor_recent_request_ids_(100000) {}
void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
+ Status s = recv_tensor_recent_request_ids_.TrackUnique(
+ request->request_id(), "RecvTensor (GdrWorker)", *request);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+
const int64 step_id = request->step_id();
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
- Status s = Rendezvous::ParseKey(key, &parsed);
+ s = Rendezvous::ParseKey(key, &parsed);
Device* src_dev = nullptr;
if (s.ok()) {
s = PrepareRecvTensor(parsed, &src_dev);
diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h
index a30b7baaed..54081f655e 100644
--- a/tensorflow/contrib/gdr/gdr_worker.h
+++ b/tensorflow/contrib/gdr/gdr_worker.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
+#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
namespace tensorflow {
@@ -38,6 +39,7 @@ class GdrWorker : public GrpcWorker {
private:
RemoteMemoryManager* remote_memory_manager_; // Not owned
+ RecentRequestIds recv_tensor_recent_request_ids_;
};
} // namespace tensorflow
diff --git a/tensorflow/contrib/hvx/README.md b/tensorflow/contrib/hvx/README.md
index 5a6f2f3086..cb3a1087de 100644
--- a/tensorflow/contrib/hvx/README.md
+++ b/tensorflow/contrib/hvx/README.md
@@ -1,60 +1,67 @@
# TensorFlow Runtime with HVX Acceleration
-## Description
+This README explain how to build and use the TensorFlow runtime with HVX Acceleration. HVX is an extension of Hexagon, a DSP provided by Qualcomm, which can compute vector calculations faster using less energy than ARM processors.
-This README explain how to build and use the TensorFlow Runtime with HVX Acceleration. HVX is an extension of Hexagon which is a DSP provided by qualcomm which can compute vector calculations faster using lower energy than ARM processors.
+## Dependencies
+
+* [Android SDK](https://developer.android.com/studio/index.html).
+* [Android NDK](https://developer.android.com/ndk/index.html). Save the path in `${NDK_ROOT}`.
+* A rooted Qualcomm-based Android device connected to the computer (preferably, a [Snapdragon Development Board](https://developer.qualcomm.com/hardware/additional-snapdragon), but it could be a rooted phone with a Qualcomm SoC, albeit this guide may not work with it). The device needs to be rooted for development and testing purposes, and shouldn't be needed in production. See [Behold, The Snapdragon MDP](https://developer.qualcomm.com/blog/behold-snapdragon-mdp) for more information.
+* [Hexagon SDK v3.0](https://developer.qualcomm.com/software/hexagon-dsp-sdk/tools). Save the path in `${QUALCOMM_SDK}`.
+* The current directory should be TensorFlow source code (`git clone https://github.com/tensorflow/tensorflow.git && cd tensorflow`), and saved into `${TF_ROOT_DIR}`.
+
+You may also need to add a test signature in the device to run HVX-based binaries. Follow the instructions in `${QUALCOMM_SDK}/docs/Tools_Signing.html`, using Python 2.
+
+Note that if the device is not rooted, you may not be able to get the serial number, push the test signature and/or run binary files that call HVX libraries.
## Quick Start Guide
-We provides several tools to build and run inference with this runtime quickly.
+We provide several tools to build and run inference with this runtime quickly.
-#### All-in-one script to run inception model with prebuild hexagon library
-If you don’t need to build your own implementation of hexagon HVX, we provide a shortcut to execute graphs by using pre-compiled binaries.
+### Run inception model with a prebuilt Hexagon library
+If you don’t need to build your own implementation of Hexagon HVX, we provide a shortcut to execute graphs by using pre-compiled binaries.
+
+```shell
+./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh -p
```
-git clone https://github.com/tensorflow/tensorflow.git
-cd tensorflow
-NDK_ROOT="/path/to/ndk" ./tensorflow/contrib/makefile/build_all_android.sh -X
-```
-(-X downloads dependencies to hexagon HVX and graphs, and copy all dependencies to android and execute a test)
-#### All-in-one script to run inception model by building entire libraries from source code
- If you want to build your own implementation of hexagon HVX, we provide a sample all-in-one script to execute graphs which downloads source and build everything for hexagon.
+The `-p` option makes the script download dependencies (i.e., Hexagon HVX binaries and graphs models), copy them to the Android device and execute a test.
-```
-git clone https://github.com/tensorflow/tensorflow.git
-cd tensorflow
-QUALCOMM_SDK="/path/to/qualcomm/sdk" NDK_ROOT="/path/to/ndk" ./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
+### Run inception model by building all from the source code
+
+If you want to build your own implementation of Hexagon HVX, we provide a sample all-in-one script to execute graphs which downloads the source and builds everything that's necessary.
+
+```shell
+./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
```
## Building libraries
If you've finished walking through the quick start guide, you may want to try building each binary manually.
-#### Build libhexagon_nn_skel.so
-Download hexagon nn library from codeaurora.org and build it.
+### Build libhexagon\_nn\_skel.so
-```
+Download Hexagon NN library from codeaurora.org and build it.
+
+```shell
git clone https://source.codeaurora.org/quic/hexagon_nn/nnlib
cd nnlib
```
-(Just follow instructions in README.HOW_TO_BUILD. You can find libhexagon_nn_skel.so in hexagon_Release_dynamic_toolv72_v60/ship)
-Then copy the generated binary to GEN_LIBS_DIR
+Just follow the instructions in `README.HOW_TO_BUILD`. You can find the file `libhexagon_nn_skel.so` in `hexagon_Release_dynamic_toolv72_v60/ship`.
+Then copy the generated binary to `${GEN_LIBS_DIR}`.
-```
+```shell
GEN_LIBS_DIR="/path/to/a/dir/to/store/hexagon/libraries"
cp -v "hexagon_Release_dynamic_toolv72_v60/ship/libhexagon_nn_skel.so" "${GEN_LIBS_DIR}"
```
-#### Build libhexagon_controller.so
+### Build libhexagon\_controller.so
+
Download tensorflow and build hexagon controller.
-```
-git clone https://github.com/tensorflow/tensorflow.git
-cd tensorflow
-TF_ROOT_DIR="$(pwd)"
-QUALCOMM_SDK="/path/to/qualcomm/sdk"
+```shell
GENERATED_NNLIB_DIRECTORY="/path/to/nnlib"
GENERATED_HEXAGON_CONTROLLER_DIRECTORY="${QUALCOMM_SDK}/examples/common/generated_hexagon_controller"
rm -rf "${GENERATED_HEXAGON_CONTROLLER_DIRECTORY}"
@@ -70,12 +77,12 @@ make tree VERBOSE=1 V=android_Release
cp -v "${GENERATED_HEXAGON_CONTROLLER_DIRECTORY}/android_Release/ship/libhexagon_controller.so" "${GEN_LIBS_DIR}"
```
-#### Build tensorflow linking hexagon library
-Build tensorflow with the build_all_android.sh with specifying -x option.
+### Build TensorFlow linking Hexagon library
-```
+Build TensorFlow with `build_all_android.sh` specifying the `-x` option.
+
+```shell
BUILD_ALL_ANDROID_PATH="${TF_ROOT_DIR}/tensorflow/contrib/makefile/build_all_android.sh"
-NDK_ROOT="/path/to/ndk/root"
CC_PREFIX=${CC_PREFIX} NDK_ROOT=${NDK_ROOT} "${BUILD_ALL_ANDROID_PATH}" \
-x "${GEN_LIBS_DIR}" \
@@ -83,11 +90,11 @@ CC_PREFIX=${CC_PREFIX} NDK_ROOT=${NDK_ROOT} "${BUILD_ALL_ANDROID_PATH}" \
-t hexagon_graph_execution
```
-#### Push binaries to your Android device
+### Push binaries to your Android device
Before running tests on your Android device, you need to push several binaries to it.
-```
+```shell
adb push "${GEN_LIBS_DIR}/libhexagon_controller.so" "/data/local/tmp"
adb push "${GEN_LIBS_DIR}/libhexagon_nn_skel.so" "/vendor/lib/rfsa/adsp"
adb push -p \
@@ -100,40 +107,54 @@ adb shell chmod "${ANDROID_EXEC_FILE_MODE}" \
adb wait-for-device
```
-#### Run tests on the device
+### Run tests on the device
Finally, you can run the inference tests on your device.
-```
+```shell
adb shell 'LD_LIBRARY_PATH=/data/local/tmp:$LD_LIBRARY_PATH' \
"/data/local/tmp/hexagon_graph_execution"
```
-#### Troubleshooting
-If you're using the Open-Q 820 Snapdragon development kit, you may run into an issue with running the executable due to a missing testsig library. From the Hexagon SDK documentation: *Dynamic shared objects are required to be digitally signed and then authenticated at runtime before they are allowed to be loaded and executed.* Generating a testsig library is necessary to run the unsigned sample library built from this project.
+### Troubleshooting
+
+#### Testsig issue
+
+If you're using the Open-Q 820 Snapdragon Development Kit, you may run into an issue with running the executable due to a missing `testsig` library. From the Hexagon SDK documentation: *Dynamic shared objects are required to be digitally signed and then authenticated at runtime before they are allowed to be loaded and executed.* Generating a testsig library is necessary to run the unsigned sample library built from this project.
-If the lack of a testsig library is your problem, you will see errors of the type:
+If the lack of a `testsig` library is your problem, you will see errors of the type:
`vendor/qcom/proprietary/adsprpc/src/fastrpc_apps_user.c:169::error: -1: 0 == (nErr = remotectl_open(name, (int*)ph, dlerrstr, sizeof(dlerrstr), &dlerr))`
-appearing in adb logcat.
-
-There are several ways to create the testsig library, the only prerequisite is Python and the correct version of the Hexagon-SDK. The following steps is one way to create this library:
-1. Run adb as root: `adb root`
-2. Run the command `adb shell cat /sys/devices/soc0/serial_number`
-3. Convert the decimal number you get as output to hex
-4. Run the python script: `python ${QUALCOMM_SDK}/tools/elfsigner/elfsigner.py -t $(SERIAL_NUMBER_HEX_VALUE)`
-5. The output of the python script is a shared library stored in ${QUALCOMM_SDK}/tools/elfsigner/output/testsig-$(SERIAL_NUMBER_HEX_VALUE).so
-6. Push the shared library to your device:
+appearing in `adb logcat` or ["Expected: (version) >= (1), actual: 0 vs 1" while running a binary from adb](https://github.com/tensorflow/tensorflow/issues/11210).
+
+You need to add a test signature, as described at the beginning of this README. After rebooting your device, you should be able to run the sample application.
+
+#### Qualcomm SDK Linux installation fails with "Malformed \uxxxx encoding"
+
+The installation file is based on LaunchAnywhere, which fails in Linux if the `PS1` env variable contains non-common Unicode chars:
+
```
-adb root
-adb wait-for-device
-adb remount
-adb wait-for-device
-adb shell mkdir /system/lib/rfsa
-adb shell mkdir /system/lib/rfsa/adsp
-adb push ${QUALCOMM_SDK}/tools/elfsigner/output/testsig-$(SERIAL_NUMBER_HEX_VALUE).so /system/lib/rfsa/adsp/
+Preparing to install...
+Extracting the JRE from the installer archive...
+Unpacking the JRE...
+Extracting the installation resources from the installer archive...
+Configuring the installer for this system's environment...
+
+Launching installer...
+
+An internal LaunchAnywhere application error has occured and this application cannot proceed. (LAX)
+
+Stack Trace:
+java.lang.IllegalArgumentException: Malformed \uxxxx encoding.
+ at java.util.Properties.loadConvert(Properties.java:574)
+ at java.util.Properties.load0(Properties.java:391)
+ at java.util.Properties.load(Properties.java:317)
+ at com.zerog.common.java.util.PropertiesUtil.loadProperties(Unknown Source)
+ at com.zerog.lax.LAX.<init>(Unknown Source)
+ at com.zerog.lax.LAX.main(Unknown Source)
```
-After rebooting your device, you should be able to run the sample application.
+It can be solved by temporarily assigning the `PS1` environment variable to something simple, such as '$'.
+
+## Maintainers
-Maintainers:
-- Satoshi Kataoka (satok@google.com, github.com/satok16)
+* Satoshi Kataoka (satok@google.com, github.com/satok16)
diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h
index 194ae2ba47..8968da6d82 100644
--- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h
+++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h
@@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
+#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
+#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -84,4 +84,4 @@ struct AdjustHsvInYiqGPU {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
+#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
index 63f45ea55b..ae787b6f1a 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
@@ -113,6 +113,42 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
self.assertListEqual(loss.input_minibatches, tower_logits)
self.assertEqual(loss.num_registered_minibatches, num_towers)
+ def testMultiplyFisherSingleVector(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.array([1., 2., 3.])
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+
+ # the LossFunction.multiply_fisher docstring only says it supports the
+ # case where the vector is the same shape as the input natural parameters
+ # (i.e. the logits here), but here we also test leading dimensions
+ vector = np.array([1., 2., 3.])
+ vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
+
+ probs = np.exp(logits - np.logaddexp.reduce(logits))
+ fisher = np.diag(probs) - np.outer(probs, probs)
+
+ for vector in vectors:
+ result = loss.multiply_fisher(vector)
+ expected_result = np.dot(vector, fisher)
+ self.assertAllClose(expected_result, sess.run(result))
+
+ def testMultiplyFisherBatch(self):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.array([[1., 2., 3.], [4., 6., 8.]])
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+
+ vector = np.array([[1., 2., 3.], [5., 3., 1.]])
+
+ na = np.newaxis
+ probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
+ keepdims=True))
+ fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
+
+ result = loss.multiply_fisher(vector)
+ expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
+ self.assertEqual(sess.run(result).shape, logits.shape)
+ self.assertAllClose(expected_result, sess.run(result))
+
class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
index 2daead2a71..cb3e698b9c 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -660,19 +660,20 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
def multiply_fisher(self, vector):
probs = self._probs
- return vector * probs - math_ops.reduce_sum(vector * probs, axis=1) * probs
+ return vector * probs - probs * math_ops.reduce_sum(
+ vector * probs, axis=-1, keep_dims=True)
def multiply_fisher_factor(self, vector):
probs = self._probs
sqrt_probs = self._sqrt_probs
return sqrt_probs * vector - probs * math_ops.reduce_sum(
- sqrt_probs * vector, axis=1, keep_dims=True)
+ sqrt_probs * vector, axis=-1, keep_dims=True)
def multiply_fisher_factor_transpose(self, vector):
probs = self._probs
sqrt_probs = self._sqrt_probs
return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
- probs * vector, axis=1, keep_dims=True)
+ probs * vector, axis=-1, keep_dims=True)
def multiply_fisher_factor_replicated_one_hot(self, index):
assert len(index) == 1, "Length of index was {}".format(len(index))
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
index 1f4a3ef568..e70b492374 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py
@@ -225,7 +225,7 @@ class LabeledTensorTest(test_util.Base):
tensor = array_ops.placeholder(dtypes.string, [None])
actual = core.LabeledTensor(tensor, ['x'])
self.assertIsNone(actual.axes['x'].size)
- self.assertIs(actual.axes['x'].value, tensor.get_shape()[0])
+ self.assertIsNone(actual.axes['x'].value.value)
def test_eq(self):
self.assertEqual(self.lt, self.lt)
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index 6c624929f2..ef419862b4 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -27,6 +27,7 @@ See the @{$python/contrib.layers} guide.
@@convolution2d_transpose
@@conv3d_transpose
@@convolution3d_transpose
+@@dense_to_sparse
@@dropout
@@elu
@@embedding_lookup_unique
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index 2eaea23177..fc8f153fe3 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -221,8 +221,8 @@ class FeatureColumnTest(test.TestCase):
weighted_sparse_col = fc.weighted_sparse_column(ids, "weights")
self.assertEqual(weighted_sparse_col.name, "ids_weighted_by_weights")
- b = fc.shared_embedding_columns([sparse_col, weighted_sparse_col],
- dimension=4, combiner="mean")
+ b = fc.shared_embedding_columns(
+ [sparse_col, weighted_sparse_col], dimension=4, combiner="mean")
self.assertEqual(len(b), 2)
self.assertEqual(b[0].shared_embedding_name,
"a1_ids_weighted_by_weights_shared_embedding")
@@ -230,8 +230,8 @@ class FeatureColumnTest(test.TestCase):
"a1_ids_weighted_by_weights_shared_embedding")
# Tries reversing order to check compatibility condition.
- b = fc.shared_embedding_columns([weighted_sparse_col, sparse_col],
- dimension=4, combiner="mean")
+ b = fc.shared_embedding_columns(
+ [weighted_sparse_col, sparse_col], dimension=4, combiner="mean")
self.assertEqual(len(b), 2)
self.assertEqual(b[0].shared_embedding_name,
"a1_ids_weighted_by_weights_shared_embedding")
@@ -240,18 +240,17 @@ class FeatureColumnTest(test.TestCase):
# Tries adding two weighted columns to check compatibility between them.
weighted_sparse_col_2 = fc.weighted_sparse_column(ids, "weights_2")
- b = fc.shared_embedding_columns([weighted_sparse_col,
- weighted_sparse_col_2],
- dimension=4, combiner="mean")
+ b = fc.shared_embedding_columns(
+ [weighted_sparse_col, weighted_sparse_col_2],
+ dimension=4,
+ combiner="mean")
self.assertEqual(len(b), 2)
self.assertEqual(
b[0].shared_embedding_name,
- "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding"
- )
+ "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding")
self.assertEqual(
b[1].shared_embedding_name,
- "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding"
- )
+ "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding")
def testSharedEmbeddingColumnDeterminism(self):
# Tests determinism in auto-generated shared_embedding_name.
@@ -286,10 +285,10 @@ class FeatureColumnTest(test.TestCase):
columns = fc.shared_embedding_columns(
[a1, a2], dimension=4, combiner="mean")
columns_copy = copy.deepcopy(columns)
- self.assertEqual(
- columns_copy[0].shared_embedding_name, "a1_a2_shared_embedding")
- self.assertEqual(
- columns_copy[1].shared_embedding_name, "a1_a2_shared_embedding")
+ self.assertEqual(columns_copy[0].shared_embedding_name,
+ "a1_a2_shared_embedding")
+ self.assertEqual(columns_copy[1].shared_embedding_name,
+ "a1_a2_shared_embedding")
def testOneHotColumn(self):
a = fc.sparse_column_with_keys("a", ["a", "b", "c", "d"])
@@ -336,11 +335,11 @@ class FeatureColumnTest(test.TestCase):
weighted_ids = fc.weighted_sparse_column(ids, "weights")
one_hot = fc.one_hot_column(weighted_ids)
features = {
- 'ids': constant_op.constant([['marlo', 'unknown', 'omar']]),
- 'weights': constant_op.constant([[2., 4., 6.]])
+ "ids": constant_op.constant([["marlo", "unknown", "omar"]]),
+ "weights": constant_op.constant([[2., 4., 6.]])
}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
- features, [one_hot])
+ features, [one_hot])
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
@@ -349,11 +348,9 @@ class FeatureColumnTest(test.TestCase):
def testMissingValueInOneHotColumnForSparseColumnWithKeys(self):
ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"])
one_hot = fc.one_hot_column(ids)
- features = {
- 'ids': constant_op.constant([['marlo', 'unknown', 'omar']])
- }
+ features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
- features, [one_hot])
+ features, [one_hot])
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
@@ -379,8 +376,7 @@ class FeatureColumnTest(test.TestCase):
self.assertEqual(d4.default_value, None)
self.assertEqual(d4.is_sparse, True)
# Default value is a list but dimension is None.
- with self.assertRaisesRegexp(ValueError,
- "Only scalar default value.*"):
+ with self.assertRaisesRegexp(ValueError, "Only scalar default value.*"):
fc._real_valued_var_len_column("g5", default_value=[2., 3.])
def testRealValuedVarLenColumnDtypes(self):
@@ -390,18 +386,19 @@ class FeatureColumnTest(test.TestCase):
"rvc": parsing_ops.VarLenFeature(dtype=dtypes.float32)
}, rvc.config)
- rvc = fc._real_valued_var_len_column("rvc", default_value=0,
- is_sparse=False)
- self.assertDictEqual(
- {
- "rvc": parsing_ops.FixedLenSequenceFeature(shape=[],
- dtype=dtypes.float32,
- allow_missing=True,
- default_value=0.0)
- }, rvc.config)
-
- rvc = fc._real_valued_var_len_column("rvc", dtype=dtypes.int32,
- default_value=0, is_sparse=True)
+ rvc = fc._real_valued_var_len_column(
+ "rvc", default_value=0, is_sparse=False)
+ self.assertDictEqual({
+ "rvc":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=0.0)
+ }, rvc.config)
+
+ rvc = fc._real_valued_var_len_column(
+ "rvc", dtype=dtypes.int32, default_value=0, is_sparse=True)
self.assertDictEqual(
{
"rvc": parsing_ops.VarLenFeature(dtype=dtypes.int32)
@@ -409,8 +406,8 @@ class FeatureColumnTest(test.TestCase):
with self.assertRaisesRegexp(TypeError,
"dtype must be convertible to float"):
- fc._real_valued_var_len_column("rvc", dtype=dtypes.string,
- default_value="", is_sparse=True)
+ fc._real_valued_var_len_column(
+ "rvc", dtype=dtypes.string, default_value="", is_sparse=True)
def testRealValuedColumn(self):
a = fc.real_valued_column("aaa")
@@ -504,13 +501,13 @@ class FeatureColumnTest(test.TestCase):
for output_rank in range(1, 3 + len(dimensions)):
with variable_scope.variable_scope("output_rank_{}".format(output_rank)):
real_valued_output = real_valued_column._to_dnn_input_layer(
- constant_op.constant(
- real_valued_input, dtype=dtypes.float32),
+ constant_op.constant(real_valued_input, dtype=dtypes.float32),
output_rank=output_rank)
with self.test_session() as sess:
real_valued_eval = sess.run(real_valued_output)
- expected_shape = (input_shape[:output_rank - 1] +
- [np.prod(input_shape[output_rank - 1:])])
+ expected_shape = (
+ input_shape[:output_rank - 1] +
+ [np.prod(input_shape[output_rank - 1:])])
self.assertEquals(expected_shape, list(real_valued_eval.shape))
def testRealValuedColumnDensification(self):
@@ -520,8 +517,7 @@ class FeatureColumnTest(test.TestCase):
"sparse_real_valued1", is_sparse=True)
sparse_tensor = sparse_tensor_lib.SparseTensor(
values=[2.0, 5.0], indices=[[0, 0], [2, 0]], dense_shape=[3, 1])
- with self.assertRaisesRegexp(
- ValueError, "Set is_sparse to False"):
+ with self.assertRaisesRegexp(ValueError, "Set is_sparse to False"):
real_valued_column._to_dnn_input_layer(sparse_tensor)
def testRealValuedColumnDeepCopy(self):
@@ -549,9 +545,8 @@ class FeatureColumnTest(test.TestCase):
def testBucketizedColumnRequiresRealValuedColumnDimension(self):
with self.assertRaisesRegexp(
TypeError, "source_column must be an instance of _RealValuedColumn.*"):
- fc.bucketized_column(fc._real_valued_var_len_column("bbb",
- is_sparse=True),
- [0])
+ fc.bucketized_column(
+ fc._real_valued_var_len_column("bbb", is_sparse=True), [0])
def testBucketizedColumnRequiresSortedBuckets(self):
with self.assertRaisesRegexp(ValueError,
@@ -654,20 +649,14 @@ class FeatureColumnTest(test.TestCase):
def testRealValuedColumnDtypes(self):
rvc = fc.real_valued_column("rvc")
- self.assertDictEqual(
- {
- "rvc": parsing_ops.FixedLenFeature(
- [1], dtype=dtypes.float32)
- },
- rvc.config)
+ self.assertDictEqual({
+ "rvc": parsing_ops.FixedLenFeature([1], dtype=dtypes.float32)
+ }, rvc.config)
rvc = fc.real_valued_column("rvc", dtype=dtypes.int32)
- self.assertDictEqual(
- {
- "rvc": parsing_ops.FixedLenFeature(
- [1], dtype=dtypes.int32)
- },
- rvc.config)
+ self.assertDictEqual({
+ "rvc": parsing_ops.FixedLenFeature([1], dtype=dtypes.int32)
+ }, rvc.config)
with self.assertRaisesRegexp(ValueError,
"dtype must be convertible to float"):
@@ -702,8 +691,9 @@ class FeatureColumnTest(test.TestCase):
batch_size = 4
dense_scalar_input = [1, 2, 3, 4]
sparse_column = fc.sparse_column_with_integerized_feature("values", 10)
- features = {"values":
- constant_op.constant(dense_scalar_input, dtype=dtypes.int64)}
+ features = {
+ "values": constant_op.constant(dense_scalar_input, dtype=dtypes.int64)
+ }
sparse_column.insert_transformed_feature(features)
sparse_output = features[sparse_column]
expected_shape = [batch_size, 1]
@@ -731,8 +721,7 @@ class FeatureColumnTest(test.TestCase):
def testSparseColumnKeysDeepCopy(self):
"""Tests deepcopy of sparse_column_with_keys."""
- column = fc.sparse_column_with_keys(
- "a", keys=["key0", "key1", "key2"])
+ column = fc.sparse_column_with_keys("a", keys=["key0", "key1", "key2"])
self.assertEqual("a", column.name)
column_copy = copy.deepcopy(column)
self.assertEqual("a", column_copy.name)
@@ -785,8 +774,9 @@ class FeatureColumnTest(test.TestCase):
a = fc.sparse_column_with_hash_bucket("cross_aaa", hash_bucket_size=100)
b = fc.sparse_column_with_hash_bucket("cross_bbb", hash_bucket_size=100)
cross_col = fc.crossed_column(set([a, b]), hash_bucket_size=10000)
- one_hot_col = fc.one_hot_column(fc.sparse_column_with_hash_bucket(
- "sparse_column_for_one_hot", hash_bucket_size=100))
+ one_hot_col = fc.one_hot_column(
+ fc.sparse_column_with_hash_bucket(
+ "sparse_column_for_one_hot", hash_bucket_size=100))
scattered_embedding_col = fc.scattered_embedding_column(
"scattered_embedding_column", size=100, dimension=10, hash_key=1)
feature_columns = set([
@@ -809,17 +799,13 @@ class FeatureColumnTest(test.TestCase):
"str_id_weights_column":
parsing_ops.VarLenFeature(dtypes.float32),
"real_valued_column1":
- parsing_ops.FixedLenFeature(
- [1], dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature([1], dtype=dtypes.float32),
"real_valued_column2":
- parsing_ops.FixedLenFeature(
- [5], dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature([5], dtype=dtypes.float32),
"real_valued_column_for_bucketization1":
- parsing_ops.FixedLenFeature(
- [1], dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature([1], dtype=dtypes.float32),
"real_valued_column_for_bucketization2":
- parsing_ops.FixedLenFeature(
- [4], dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature([4], dtype=dtypes.float32),
"cross_aaa":
parsing_ops.VarLenFeature(dtypes.string),
"cross_bbb":
@@ -849,11 +835,14 @@ class FeatureColumnTest(test.TestCase):
real_valued_col0 = fc._real_valued_var_len_column(
"real_valued_column0", is_sparse=True)
real_valued_col1 = fc._real_valued_var_len_column(
- "real_valued_column1", dtype=dtypes.int64, default_value=0,
+ "real_valued_column1",
+ dtype=dtypes.int64,
+ default_value=0,
is_sparse=False)
feature_columns = set([real_valued_col0, real_valued_col1])
expected_config = {
- "real_valued_column0": parsing_ops.VarLenFeature(dtype=dtypes.float32),
+ "real_valued_column0":
+ parsing_ops.VarLenFeature(dtype=dtypes.float32),
"real_valued_column1":
parsing_ops.FixedLenSequenceFeature(
[], dtype=dtypes.int64, allow_missing=True, default_value=0),
@@ -874,7 +863,9 @@ class FeatureColumnTest(test.TestCase):
real_valued_col5 = fc._real_valued_var_len_column(
"real_valued_column5", default_value=2, is_sparse=True)
real_valued_col6 = fc._real_valued_var_len_column(
- "real_valued_column6", dtype=dtypes.int64, default_value=1,
+ "real_valued_column6",
+ dtype=dtypes.int64,
+ default_value=1,
is_sparse=False)
feature_columns = [
real_valued_col1, real_valued_col2, real_valued_col3, real_valued_col4,
@@ -902,8 +893,7 @@ class FeatureColumnTest(test.TestCase):
parsing_ops.VarLenFeature(dtype=dtypes.float32),
"real_valued_column6":
parsing_ops.FixedLenSequenceFeature(
- [], dtype=dtypes.int64, allow_missing=True,
- default_value=1)
+ [], dtype=dtypes.int64, allow_missing=True, default_value=1)
},
config)
@@ -1104,8 +1094,8 @@ class FeatureColumnTest(test.TestCase):
# This will initialize the crossed column weights from provided checkpoint
# and return a [4, 1] tensor which is same as weights variable. Since we
# won't modify weights, this should be same as 'saved_col_weights'.
- _, col_weights, _ = (feature_column_ops.weighted_sum_from_feature_columns(
- {
+ _, col_weights, _ = (
+ feature_column_ops.weighted_sum_from_feature_columns({
sparse_col_1.name: input_tensor,
sparse_col_2.name: input_tensor
}, [crossed_col_initialized], 1))
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index f3229a1605..c8e3307ee8 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
@@ -54,47 +55,18 @@ from tensorflow.python.layers.maxout import maxout
# TODO(b/28426988): Replace legacy_* fns migrated from slim.
# TODO(b/28426988): Remove legacy_* when all uses have migrated to new API.
-__all__ = ['avg_pool2d',
- 'avg_pool3d',
- 'batch_norm',
- 'bias_add',
- 'conv2d',
- 'conv3d',
- 'conv2d_in_plane',
- 'conv2d_transpose',
- 'conv3d_transpose',
- 'convolution',
- 'convolution2d',
- 'convolution2d_in_plane',
- 'convolution2d_transpose',
- 'convolution3d',
- 'convolution3d_transpose',
- 'dropout',
- 'elu',
- 'flatten',
- 'fully_connected',
- 'GDN',
- 'gdn',
- 'layer_norm',
- 'linear',
- 'pool',
- 'max_pool2d',
- 'max_pool3d',
- 'one_hot_encoding',
- 'relu',
- 'relu6',
- 'repeat',
- 'scale_gradient',
- 'separable_conv2d',
- 'separable_convolution2d',
- 'softmax',
- 'spatial_softmax',
- 'stack',
- 'unit_norm',
- 'legacy_fully_connected',
- 'legacy_linear',
- 'legacy_relu',
- 'maxout']
+__all__ = [
+ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d',
+ 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution',
+ 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose',
+ 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse',
+ 'dropout', 'elu', 'flatten',
+ 'fully_connected', 'GDN', 'gdn', 'layer_norm', 'linear', 'pool',
+ 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat',
+ 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'softmax',
+ 'spatial_softmax', 'stack', 'unit_norm', 'legacy_fully_connected',
+ 'legacy_linear', 'legacy_relu', 'maxout'
+]
DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
@@ -139,13 +111,14 @@ def avg_pool2d(inputs,
raise ValueError('data_format has to be either NCHW or NHWC.')
with ops.name_scope(scope, 'AvgPool2D', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = pooling_layers.AveragePooling2D(pool_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- _scope=sc)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = pooling_layers.AveragePooling2D(
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ _scope=sc)
outputs = layer.apply(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
@@ -187,13 +160,14 @@ def avg_pool3d(inputs,
raise ValueError('data_format has to be either NCDHW or NDHWC.')
with ops.name_scope(scope, 'AvgPool3D', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = pooling_layers.AveragePooling3D(pool_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- _scope=sc)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = pooling_layers.AveragePooling3D(
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ _scope=sc)
outputs = layer.apply(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
@@ -298,8 +272,8 @@ def _fused_batch_norm(inputs,
raise ValueError('Inputs %s has undefined rank' % inputs.name)
elif original_rank not in [2, 4]:
raise ValueError('Inputs %s has unsupported rank.'
- ' Expected 2 or 4 but got %d' % (
- inputs.name, original_rank))
+ ' Expected 2 or 4 but got %d' % (inputs.name,
+ original_rank))
if original_rank == 2:
channels = inputs.get_shape()[-1].value
if channels is None:
@@ -393,6 +367,7 @@ def _fused_batch_norm(inputs,
def _fused_batch_norm_training():
return nn.fused_batch_norm(
inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
+
def _fused_batch_norm_inference():
return nn.fused_batch_norm(
inputs,
@@ -403,9 +378,9 @@ def _fused_batch_norm(inputs,
epsilon=epsilon,
is_training=False,
data_format=data_format)
- outputs, mean, variance = utils.smart_cond(is_training,
- _fused_batch_norm_training,
- _fused_batch_norm_inference)
+
+ outputs, mean, variance = utils.smart_cond(
+ is_training, _fused_batch_norm_training, _fused_batch_norm_inference)
# If `is_training` doesn't have a constant value, because it is a `Tensor`,
# a `Variable` or `Placeholder` then is_training_value will be None and
@@ -415,6 +390,7 @@ def _fused_batch_norm(inputs,
if need_updates:
if updates_collections is None:
no_updates = lambda: outputs
+
def _force_updates():
"""Internal function forces updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
@@ -424,9 +400,11 @@ def _fused_batch_norm(inputs,
with ops.control_dependencies(
[update_moving_mean, update_moving_variance]):
return array_ops.identity(outputs)
+
outputs = utils.smart_cond(is_training, _force_updates, no_updates)
else:
moving_vars_fn = lambda: (moving_mean, moving_variance)
+
def _delay_updates():
"""Internal function that delay updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
@@ -434,9 +412,9 @@ def _fused_batch_norm(inputs,
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay, zero_debias=False)
return update_moving_mean, update_moving_variance
- update_mean, update_variance = utils.smart_cond(is_training,
- _delay_updates,
- moving_vars_fn)
+
+ update_mean, update_variance = utils.smart_cond(
+ is_training, _delay_updates, moving_vars_fn)
ops.add_to_collections(updates_collections, update_mean)
ops.add_to_collections(updates_collections, update_variance)
@@ -479,7 +457,12 @@ def batch_norm(inputs,
Sergey Ioffe, Christian Szegedy
- Can be used as a normalizer function for conv2d and fully_connected.
+ Can be used as a normalizer function for conv2d and fully_connected. The
+ normalization is over all but the last dimension if `data_format` is `NHWC`
+ and all but the second dimension if `data_format` is `NCHW`. In case of a 2D
+ tensor this corresponds to the batch dimension, while in case of a 4D tensor
+ this
+ corresponds to the batch and space dimensions.
Note: when training, the moving_mean and moving_variance need to be updated.
By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
@@ -588,10 +571,9 @@ def batch_norm(inputs,
# implementation in normalization_layers.BatchNormalization.
inputs = ops.convert_to_tensor(inputs)
rank = inputs.get_shape().ndims
- possible_to_fuse = (batch_weights is None and
- not renorm and
- rank in [2, 4] and
- adjustment is None)
+ possible_to_fuse = (
+ batch_weights is None and not renorm and rank in [2, 4] and
+ adjustment is None)
if fused and possible_to_fuse and (
zero_debias_moving_mean or rank == 2 or
updates_collections is not ops.GraphKeys.UPDATE_OPS):
@@ -619,7 +601,9 @@ def batch_norm(inputs,
layer_variable_getter = _build_variable_getter()
with variable_scope.variable_scope(
- scope, 'BatchNorm', [inputs], reuse=reuse,
+ scope,
+ 'BatchNorm', [inputs],
+ reuse=reuse,
custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
@@ -667,15 +651,15 @@ def batch_norm(inputs,
outputs = layer.apply(inputs, training=is_training)
# Add variables to collections.
- _add_variable_to_collections(
- layer.moving_mean, variables_collections, 'moving_mean')
- _add_variable_to_collections(
- layer.moving_variance, variables_collections, 'moving_variance')
+ _add_variable_to_collections(layer.moving_mean, variables_collections,
+ 'moving_mean')
+ _add_variable_to_collections(layer.moving_variance, variables_collections,
+ 'moving_variance')
if layer.beta is not None:
_add_variable_to_collections(layer.beta, variables_collections, 'beta')
if layer.gamma is not None:
- _add_variable_to_collections(
- layer.gamma, variables_collections, 'gamma')
+ _add_variable_to_collections(layer.gamma, variables_collections,
+ 'gamma')
if activation_fn is not None:
outputs = activation_fn(outputs)
@@ -715,8 +699,8 @@ def batch_norm(inputs,
params_shape = inputs_shape[-1:]
params_shape_broadcast = None
if not params_shape.is_fully_defined():
- raise ValueError('Inputs %s has undefined channels dimension %s.' % (
- inputs.name, params_shape))
+ raise ValueError('Inputs %s has undefined channels dimension %s.' %
+ (inputs.name, params_shape))
# Allocate parameters for the beta and gamma of the normalization.
beta, gamma = None, None
@@ -727,23 +711,25 @@ def batch_norm(inputs,
'beta')
beta_initializer = param_initializers.get('beta',
init_ops.zeros_initializer())
- beta = variables.model_variable('beta',
- shape=params_shape,
- dtype=dtype,
- initializer=beta_initializer,
- collections=beta_collections,
- trainable=trainable)
+ beta = variables.model_variable(
+ 'beta',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=beta_initializer,
+ collections=beta_collections,
+ trainable=trainable)
if scale:
- gamma_collections = utils.get_variable_collections(variables_collections,
- 'gamma')
+ gamma_collections = utils.get_variable_collections(
+ variables_collections, 'gamma')
gamma_initializer = param_initializers.get('gamma',
init_ops.ones_initializer())
- gamma = variables.model_variable('gamma',
- shape=params_shape,
- dtype=dtype,
- initializer=gamma_initializer,
- collections=gamma_collections,
- trainable=trainable)
+ gamma = variables.model_variable(
+ 'gamma',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=gamma_initializer,
+ collections=gamma_collections,
+ trainable=trainable)
# Create moving_mean and moving_variance variables and add them to the
# appropriate collections. We disable variable partitioning while creating
@@ -792,8 +778,8 @@ def batch_norm(inputs,
mean, variance = nn.moments(inputs, moments_axes)
else:
if data_format == DATA_FORMAT_NCHW:
- mean, variance = nn.weighted_moments(inputs, moments_axes,
- batch_weights, keep_dims=True)
+ mean, variance = nn.weighted_moments(
+ inputs, moments_axes, batch_weights, keep_dims=True)
mean = array_ops.reshape(mean, [-1])
variance = array_ops.reshape(variance, [-1])
else:
@@ -802,19 +788,21 @@ def batch_norm(inputs,
moving_vars_fn = lambda: (moving_mean, moving_variance)
if updates_collections is None:
+
def _force_updates():
"""Internal function forces updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay, zero_debias=False)
- with ops.control_dependencies([update_moving_mean,
- update_moving_variance]):
+ with ops.control_dependencies(
+ [update_moving_mean, update_moving_variance]):
return array_ops.identity(mean), array_ops.identity(variance)
- mean, variance = utils.smart_cond(is_training,
- _force_updates,
+
+ mean, variance = utils.smart_cond(is_training, _force_updates,
moving_vars_fn)
else:
+
def _delay_updates():
"""Internal function that delay updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
@@ -823,9 +811,8 @@ def batch_norm(inputs,
moving_variance, variance, decay, zero_debias=False)
return update_moving_mean, update_moving_variance
- update_mean, update_variance = utils.smart_cond(is_training,
- _delay_updates,
- moving_vars_fn)
+ update_mean, update_variance = utils.smart_cond(
+ is_training, _delay_updates, moving_vars_fn)
ops.add_to_collections(updates_collections, update_mean)
ops.add_to_collections(updates_collections, update_variance)
# Use computed moments during training and moving_vars otherwise.
@@ -893,8 +880,8 @@ def bias_add(inputs,
"""
if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
- with variable_scope.variable_scope(scope, 'BiasAdd', [inputs],
- reuse=reuse) as sc:
+ with variable_scope.variable_scope(
+ scope, 'BiasAdd', [inputs], reuse=reuse) as sc:
inputs = ops.convert_to_tensor(inputs)
dtype = inputs.dtype.base_dtype
inputs_shape = inputs.get_shape()
@@ -909,13 +896,16 @@ def bias_add(inputs,
raise ValueError('`C` dimension must be known but is None')
biases_collections = utils.get_variable_collections(variables_collections,
'biases')
- biases = variables.model_variable('biases',
- shape=[num_features,],
- dtype=dtype,
- initializer=initializer,
- regularizer=regularizer,
- collections=biases_collections,
- trainable=trainable)
+ biases = variables.model_variable(
+ 'biases',
+ shape=[
+ num_features,
+ ],
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ collections=biases_collections,
+ trainable=trainable)
outputs = nn.bias_add(inputs, biases, data_format=data_format)
if activation_fn is not None:
outputs = activation_fn(outputs)
@@ -1015,8 +1005,10 @@ def convolution(inputs,
if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
raise ValueError('Invalid data_format: %r' % (data_format,))
- layer_variable_getter = _build_variable_getter(
- {'bias': 'biases', 'kernel': 'weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
with variable_scope.variable_scope(
scope, 'Conv', [inputs], reuse=reuse,
@@ -1034,26 +1026,27 @@ def convolution(inputs,
raise ValueError('Convolution not supported for input with rank',
input_rank)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = layer_class(filters=num_outputs,
- kernel_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- dilation_rate=rate,
- activation=None,
- use_bias=not normalizer_fn and biases_initializer,
- kernel_initializer=weights_initializer,
- bias_initializer=biases_initializer,
- kernel_regularizer=weights_regularizer,
- bias_regularizer=biases_regularizer,
- activity_regularizer=None,
- trainable=trainable,
- name=sc.name,
- dtype=inputs.dtype.base_dtype,
- _scope=sc,
- _reuse=reuse)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = layer_class(
+ filters=num_outputs,
+ kernel_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ dilation_rate=rate,
+ activation=None,
+ use_bias=not normalizer_fn and biases_initializer,
+ kernel_initializer=weights_initializer,
+ bias_initializer=biases_initializer,
+ kernel_regularizer=weights_regularizer,
+ bias_regularizer=biases_regularizer,
+ activity_regularizer=None,
+ trainable=trainable,
+ name=sc.name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=sc,
+ _reuse=reuse)
outputs = layer.apply(inputs)
# Add variables to collections.
@@ -1069,6 +1062,7 @@ def convolution(inputs,
outputs = activation_fn(outputs)
return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
+
convolution2d = convolution
convolution3d = convolution
@@ -1144,13 +1138,14 @@ def convolution2d_in_plane(
weights_shape = [kernel_h, kernel_w, 1, 1]
weights_collections = utils.get_variable_collections(
variables_collections, 'weights')
- weights = variables.model_variable('weights',
- shape=weights_shape,
- dtype=dtype,
- initializer=weights_initializer,
- regularizer=weights_regularizer,
- collections=weights_collections,
- trainable=trainable)
+ weights = variables.model_variable(
+ 'weights',
+ shape=weights_shape,
+ dtype=dtype,
+ initializer=weights_initializer,
+ regularizer=weights_regularizer,
+ collections=weights_collections,
+ trainable=trainable)
depthwise_weights = array_ops.tile(weights, [1, 1, num_filters_in, 1])
outputs = nn.depthwise_conv2d(inputs, depthwise_weights,
[1, stride_h, stride_w, 1], padding)
@@ -1161,13 +1156,16 @@ def convolution2d_in_plane(
if biases_initializer is not None:
biases_collections = utils.get_variable_collections(
variables_collections, 'biases')
- biases = variables.model_variable('biases',
- shape=[num_filters_in,],
- dtype=dtype,
- initializer=biases_initializer,
- regularizer=biases_regularizer,
- collections=biases_collections,
- trainable=trainable)
+ biases = variables.model_variable(
+ 'biases',
+ shape=[
+ num_filters_in,
+ ],
+ dtype=dtype,
+ initializer=biases_initializer,
+ regularizer=biases_regularizer,
+ collections=biases_collections,
+ trainable=trainable)
outputs = nn.bias_add(outputs, biases)
if activation_fn is not None:
@@ -1240,19 +1238,23 @@ def convolution2d_transpose(
ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
ValueError: If `C` dimension of `inputs` is None.
"""
- layer_variable_getter = _build_variable_getter(
- {'bias': 'biases', 'kernel': 'weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
with variable_scope.variable_scope(
- scope, 'Conv2d_transpose', [inputs], reuse=reuse,
+ scope,
+ 'Conv2d_transpose', [inputs],
+ reuse=reuse,
custom_getter=layer_variable_getter) as sc:
if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
layer = convolutional_layers.Convolution2DTranspose(
filters=num_outputs,
kernel_size=kernel_size,
@@ -1349,19 +1351,23 @@ def convolution3d_transpose(
ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
ValueError: If `C` dimension of `inputs` is None.
"""
- layer_variable_getter = _build_variable_getter(
- {'bias': 'biases', 'kernel': 'weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
with variable_scope.variable_scope(
- scope, 'Conv3d_transpose', [inputs], reuse=reuse,
+ scope,
+ 'Conv3d_transpose', [inputs],
+ reuse=reuse,
custom_getter=layer_variable_getter) as sc:
if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
raise ValueError('data_format has to be either NCDHW or NDHWC.')
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
layer = convolutional_layers.Convolution3DTranspose(
filters=num_outputs,
kernel_size=kernel_size,
@@ -1397,6 +1403,29 @@ def convolution3d_transpose(
@add_arg_scope
+def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None):
+ """Converts a dense tensor into a sparse tensor.
+ An example use would be to convert dense labels to sparse ones
+ so that they can be fed to the ctc_loss.
+
+ Args:
+ tensor: An `int` `Tensor` to be converted to a `Sparse`.
+ eos_token: An integer.
+ It is part of the target label that signfies the end of a sentence.
+ outputs_collections: Collection to add the outputs.
+ scope: Optional scope for name_scope.
+ """
+ with variable_scope.variable_scope(
+ scope, 'dense_to_sparse', [tensor]) as sc:
+ tensor = ops.convert_to_tensor(tensor)
+ indices = array_ops.where(math_ops.not_equal(tensor, constant_op.constant(eos_token, tensor.dtype)))
+ values = array_ops.gather_nd(tensor, indices)
+ shape = array_ops.shape(tensor, out_type=dtypes.int64)
+ outputs = sparse_tensor.SparseTensor(indices, values, shape)
+ return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
+
+
+@add_arg_scope
def dropout(inputs,
keep_prob=0.5,
noise_shape=None,
@@ -1430,19 +1459,18 @@ def dropout(inputs,
with variable_scope.variable_scope(
scope, 'Dropout', [inputs], custom_getter=_model_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
- layer = core_layers.Dropout(rate=1 - keep_prob,
- noise_shape=noise_shape,
- seed=seed,
- name=sc.name,
- _scope=sc)
+ layer = core_layers.Dropout(
+ rate=1 - keep_prob,
+ noise_shape=noise_shape,
+ seed=seed,
+ name=sc.name,
+ _scope=sc)
outputs = layer.apply(inputs, training=is_training)
return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
@add_arg_scope
-def flatten(inputs,
- outputs_collections=None,
- scope=None):
+def flatten(inputs, outputs_collections=None, scope=None):
"""Flattens the input while maintaining the batch_size.
Assumes that the first dimension represents the batch.
@@ -1474,8 +1502,8 @@ def _sparse_inner_flatten(inputs, new_rank):
outer_dimensions = inputs.dense_shape[:new_rank - 1]
inner_dimensions = inputs.dense_shape[new_rank - 1:]
- new_shape = array_ops.concat((outer_dimensions,
- [math_ops.reduce_prod(inner_dimensions)]), 0)
+ new_shape = array_ops.concat(
+ (outer_dimensions, [math_ops.reduce_prod(inner_dimensions)]), 0)
flattened = sparse_ops.sparse_reshape(inputs, new_shape)
return flattened
@@ -1541,10 +1569,18 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None):
return utils.collect_named_outputs(output_collections, sc, flattened)
-def _model_variable_getter(getter, name, shape=None, dtype=None,
- initializer=None, regularizer=None, trainable=True,
- collections=None, caching_device=None,
- partitioner=None, rename=None, use_resource=None,
+def _model_variable_getter(getter,
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ rename=None,
+ use_resource=None,
**_):
"""Getter that uses model_variable for compatibility with core layers."""
short_name = name.split('/')[-1]
@@ -1553,25 +1589,34 @@ def _model_variable_getter(getter, name, shape=None, dtype=None,
name_components[-1] = rename[short_name]
name = '/'.join(name_components)
return variables.model_variable(
- name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, collections=collections, trainable=trainable,
- caching_device=caching_device, partitioner=partitioner,
- custom_getter=getter, use_resource=use_resource)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ collections=collections,
+ trainable=trainable,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ custom_getter=getter,
+ use_resource=use_resource)
def _build_variable_getter(rename=None):
"""Build a model variable getter that respects scope getter and renames."""
+
# VariableScope will nest the getters
def layer_variable_getter(getter, *args, **kwargs):
kwargs['rename'] = rename
return _model_variable_getter(getter, *args, **kwargs)
+
return layer_variable_getter
def _add_variable_to_collections(variable, collections_set, collections_name):
"""Adds variable (or all its parts) to all collections with that name."""
- collections = utils.get_variable_collections(
- collections_set, collections_name) or []
+ collections = utils.get_variable_collections(collections_set,
+ collections_name) or []
variables_list = [variable]
if isinstance(variable, tf_variables.PartitionedVariable):
variables_list = [v for v in variable]
@@ -1640,15 +1685,19 @@ def fully_connected(inputs,
ValueError: If x has rank less than 2 or if its last dimension is not set.
"""
if not isinstance(num_outputs, six.integer_types):
- raise ValueError(
- 'num_outputs should be int or long, got %s.' % (num_outputs,))
+ raise ValueError('num_outputs should be int or long, got %s.' %
+ (num_outputs,))
- layer_variable_getter = _build_variable_getter({'bias': 'biases',
- 'kernel': 'weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'kernel': 'weights'
+ })
with variable_scope.variable_scope(
- scope, 'fully_connected', [inputs],
- reuse=reuse, custom_getter=layer_variable_getter) as sc:
+ scope,
+ 'fully_connected', [inputs],
+ reuse=reuse,
+ custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
layer = core_layers.Dense(
units=num_outputs,
@@ -1754,15 +1803,17 @@ class GDN(base.Layer):
inverse=False,
beta_min=1e-6,
gamma_init=.1,
- reparam_offset=2 ** -18,
+ reparam_offset=2**-18,
data_format='channels_last',
activity_regularizer=None,
trainable=True,
name=None,
**kwargs):
- super(GDN, self).__init__(trainable=trainable, name=name,
- activity_regularizer=activity_regularizer,
- **kwargs)
+ super(GDN, self).__init__(
+ trainable=trainable,
+ name=name,
+ activity_regularizer=activity_regularizer,
+ **kwargs)
self.inverse = inverse
self._beta_min = beta_min
self._gamma_init = gamma_init
@@ -1797,8 +1848,9 @@ class GDN(base.Layer):
with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope:
inputs = ops.convert_to_tensor(inputs, name='inputs')
bound = ops.convert_to_tensor(bound, name='bound')
- with ops.get_default_graph().gradient_override_map(
- {'Maximum': 'GDNLowerBound'}):
+ with ops.get_default_graph().gradient_override_map({
+ 'Maximum': 'GDNLowerBound'
+ }):
return math_ops.maximum(inputs, bound, name=scope)
@staticmethod
@@ -1825,12 +1877,14 @@ class GDN(base.Layer):
raise ValueError('The channel dimension of the inputs to `GDN` '
'must be defined.')
self._input_rank = input_shape.ndims
- self.input_spec = base.InputSpec(ndim=input_shape.ndims,
- axes={channel_axis: num_channels})
+ self.input_spec = base.InputSpec(
+ ndim=input_shape.ndims, axes={
+ channel_axis: num_channels
+ })
- pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype)
+ pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype)
beta_bound = array_ops.constant(
- (self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype)
+ (self._beta_min + self._reparam_offset**2)**.5, dtype=self.dtype)
gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype)
def beta_initializer(shape, dtype=None, partition_info=None):
@@ -1844,19 +1898,21 @@ class GDN(base.Layer):
eye = linalg_ops.eye(shape[0], dtype=dtype)
return math_ops.sqrt(self._gamma_init * eye + pedestal)
- beta = self.add_variable('reparam_beta',
- shape=[num_channels],
- initializer=beta_initializer,
- dtype=self.dtype,
- trainable=True)
+ beta = self.add_variable(
+ 'reparam_beta',
+ shape=[num_channels],
+ initializer=beta_initializer,
+ dtype=self.dtype,
+ trainable=True)
beta = self._lower_bound(beta, beta_bound)
self.beta = math_ops.square(beta) - pedestal
- gamma = self.add_variable('reparam_gamma',
- shape=[num_channels, num_channels],
- initializer=gamma_initializer,
- dtype=self.dtype,
- trainable=True)
+ gamma = self.add_variable(
+ 'reparam_gamma',
+ shape=[num_channels, num_channels],
+ initializer=gamma_initializer,
+ dtype=self.dtype,
+ trainable=True)
gamma = self._lower_bound(gamma, gamma_bound)
self.gamma = math_ops.square(gamma) - pedestal
@@ -1871,8 +1927,11 @@ class GDN(base.Layer):
# Compute normalization pool.
if self.data_format == 'channels_first':
- norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID',
- data_format='NC' + 'DHW'[-(ndim - 2):])
+ norm_pool = nn.convolution(
+ math_ops.square(inputs),
+ gamma,
+ 'VALID',
+ data_format='NC' + 'DHW' [-(ndim - 2):])
if ndim == 3:
norm_pool = array_ops.expand_dims(norm_pool, 2)
norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
@@ -1914,7 +1973,7 @@ def gdn(inputs,
inverse=False,
beta_min=1e-6,
gamma_init=.1,
- reparam_offset=2 ** -18,
+ reparam_offset=2**-18,
data_format='channels_last',
activity_regularizer=None,
trainable=True,
@@ -1980,17 +2039,18 @@ def gdn(inputs,
Returns:
Output tensor.
"""
- layer = GDN(inverse=inverse,
- beta_min=beta_min,
- gamma_init=gamma_init,
- reparam_offset=reparam_offset,
- data_format=data_format,
- activity_regularizer=activity_regularizer,
- trainable=trainable,
- name=name,
- dtype=inputs.dtype.base_dtype,
- _scope=name,
- _reuse=reuse)
+ layer = GDN(
+ inverse=inverse,
+ beta_min=beta_min,
+ gamma_init=gamma_init,
+ reparam_offset=reparam_offset,
+ data_format=data_format,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ name=name,
+ dtype=inputs.dtype.base_dtype,
+ _scope=name,
+ _reuse=reuse)
return layer.apply(inputs)
@@ -2066,8 +2126,8 @@ def layer_norm(inputs,
or if `inputs.shape[begin_params_axis:]` is not fully defined at
graph build time.
"""
- with variable_scope.variable_scope(scope, 'LayerNorm', [inputs],
- reuse=reuse) as sc:
+ with variable_scope.variable_scope(
+ scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
inputs = ops.convert_to_tensor(inputs)
inputs_shape = inputs.shape
inputs_rank = inputs_shape.ndims
@@ -2077,15 +2137,14 @@ def layer_norm(inputs,
if begin_norm_axis < 0:
begin_norm_axis = inputs_rank + begin_norm_axis
if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
- raise ValueError(
- 'begin_params_axis (%d) and begin_norm_axis (%d) '
- 'must be < rank(inputs) (%d)'
- % (begin_params_axis, begin_norm_axis, inputs_rank))
+ raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
+ 'must be < rank(inputs) (%d)' %
+ (begin_params_axis, begin_norm_axis, inputs_rank))
params_shape = inputs_shape[begin_params_axis:]
if not params_shape.is_fully_defined():
raise ValueError(
- 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % (
- inputs.name, begin_params_axis, inputs_shape))
+ 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
+ (inputs.name, begin_params_axis, inputs_shape))
# Allocate parameters for the beta and gamma of the normalization.
beta, gamma = None, None
if center:
@@ -2099,8 +2158,8 @@ def layer_norm(inputs,
collections=beta_collections,
trainable=trainable)
if scale:
- gamma_collections = utils.get_variable_collections(variables_collections,
- 'gamma')
+ gamma_collections = utils.get_variable_collections(
+ variables_collections, 'gamma')
gamma = variables.model_variable(
'gamma',
shape=params_shape,
@@ -2114,7 +2173,11 @@ def layer_norm(inputs,
# Compute layer normalization using the batch_normalization function.
variance_epsilon = 1e-12
outputs = nn.batch_normalization(
- inputs, mean, variance, offset=beta, scale=gamma,
+ inputs,
+ mean,
+ variance,
+ offset=beta,
+ scale=gamma,
variance_epsilon=variance_epsilon)
outputs.set_shape(inputs_shape)
if activation_fn is not None:
@@ -2160,13 +2223,14 @@ def max_pool2d(inputs,
raise ValueError('data_format has to be either NCHW or NHWC.')
with ops.name_scope(scope, 'MaxPool2D', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = pooling_layers.MaxPooling2D(pool_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- _scope=sc)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = pooling_layers.MaxPooling2D(
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ _scope=sc)
outputs = layer.apply(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
@@ -2209,13 +2273,14 @@ def max_pool3d(inputs,
raise ValueError('data_format has to be either NCDHW or NDHWC.')
with ops.name_scope(scope, 'MaxPool3D', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
- layer = pooling_layers.MaxPooling3D(pool_size=kernel_size,
- strides=stride,
- padding=padding,
- data_format=df,
- _scope=sc)
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
+ layer = pooling_layers.MaxPooling3D(
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ data_format=df,
+ _scope=sc)
outputs = layer.apply(inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
@@ -2268,8 +2333,8 @@ def pool(inputs,
"""
# pylint: enable=line-too-long
- with ops.name_scope(scope, '%s_pool' %
- (pooling_type.lower()), [inputs]) as sc:
+ with ops.name_scope(scope, '%s_pool' % (pooling_type.lower()),
+ [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
input_rank = inputs.get_shape().ndims
if input_rank is None:
@@ -2314,18 +2379,16 @@ def one_hot_encoding(labels,
labels = ops.convert_to_tensor(labels)
if labels.dtype == dtypes.int32:
labels = standard_ops.to_int64(labels)
- outputs = standard_ops.one_hot(labels,
- num_classes,
- on_value=on_value,
- off_value=off_value)
+ outputs = standard_ops.one_hot(
+ labels, num_classes, on_value=on_value, off_value=off_value)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
def _apply_activation(y, activation_fn, output_collections):
if activation_fn is not None:
y = activation_fn(y)
- ops.add_to_collections(list(output_collections or []) +
- [ops.GraphKeys.ACTIVATIONS], y)
+ ops.add_to_collections(
+ list(output_collections or []) + [ops.GraphKeys.ACTIVATIONS], y)
return y
@@ -2370,7 +2433,7 @@ def repeat(inputs, repetitions, layer, *args, **kwargs):
scope = 'repeat'
outputs = inputs
for i in range(repetitions):
- kwargs['scope'] = scope + '_' + str(i+1)
+ kwargs['scope'] = scope + '_' + str(i + 1)
outputs = layer(outputs, *args, **kwargs)
return outputs
@@ -2385,8 +2448,8 @@ def _scale_gradient_grad(op, grad):
return [grad * op.inputs[1], None]
-@function.Defun(python_grad_func=_scale_gradient_grad,
- shape_func=_scale_gradient_shape)
+@function.Defun(
+ python_grad_func=_scale_gradient_grad, shape_func=_scale_gradient_shape)
def scale_gradient(inputs, gradient_multiplier):
"""Identity operation, but with the gradient multiplied by a tensor.
@@ -2491,18 +2554,21 @@ def separable_convolution2d(
"""
if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
raise ValueError('data_format has to be either NCHW or NHWC.')
- layer_variable_getter = _build_variable_getter(
- {'bias': 'biases',
- 'depthwise_kernel': 'depthwise_weights',
- 'pointwise_kernel': 'pointwise_weights'})
+ layer_variable_getter = _build_variable_getter({
+ 'bias': 'biases',
+ 'depthwise_kernel': 'depthwise_weights',
+ 'pointwise_kernel': 'pointwise_weights'
+ })
with variable_scope.variable_scope(
- scope, 'SeparableConv2d', [inputs], reuse=reuse,
+ scope,
+ 'SeparableConv2d', [inputs],
+ reuse=reuse,
custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
- df = ('channels_first' if data_format and data_format.startswith('NC')
- else 'channels_last')
+ df = ('channels_first'
+ if data_format and data_format.startswith('NC') else 'channels_last')
if num_outputs is not None:
# Apply separable conv using the SeparableConvolution2D layer.
layer = convolutional_layers.SeparableConvolution2D(
@@ -2535,8 +2601,8 @@ def separable_convolution2d(
_add_variable_to_collections(layer.pointwise_kernel,
variables_collections, 'weights')
if layer.bias is not None:
- _add_variable_to_collections(layer.bias,
- variables_collections, 'biases')
+ _add_variable_to_collections(layer.bias, variables_collections,
+ 'biases')
if normalizer_fn is not None:
normalizer_params = normalizer_params or {}
@@ -2551,8 +2617,7 @@ def separable_convolution2d(
weights_collections = utils.get_variable_collections(
variables_collections, 'weights')
- depthwise_shape = [kernel_h, kernel_w,
- num_filters_in, depth_multiplier]
+ depthwise_shape = [kernel_h, kernel_w, num_filters_in, depth_multiplier]
depthwise_weights = variables.model_variable(
'depthwise_weights',
shape=depthwise_shape,
@@ -2566,9 +2631,13 @@ def separable_convolution2d(
1, stride_h, stride_w, 1
]
- outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding,
- rate=utils.two_element_tuple(rate),
- data_format=data_format)
+ outputs = nn.depthwise_conv2d(
+ inputs,
+ depthwise_weights,
+ strides,
+ padding,
+ rate=utils.two_element_tuple(rate),
+ data_format=data_format)
num_outputs = depth_multiplier * num_filters_in
if normalizer_fn is not None:
@@ -2578,13 +2647,16 @@ def separable_convolution2d(
if biases_initializer is not None:
biases_collections = utils.get_variable_collections(
variables_collections, 'biases')
- biases = variables.model_variable('biases',
- shape=[num_outputs,],
- dtype=dtype,
- initializer=biases_initializer,
- regularizer=biases_regularizer,
- trainable=trainable,
- collections=biases_collections)
+ biases = variables.model_variable(
+ 'biases',
+ shape=[
+ num_outputs,
+ ],
+ dtype=dtype,
+ initializer=biases_initializer,
+ regularizer=biases_regularizer,
+ trainable=trainable,
+ collections=biases_collections)
outputs = nn.bias_add(outputs, biases, data_format=data_format)
if activation_fn is not None:
@@ -2669,23 +2741,24 @@ def spatial_softmax(features,
with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]):
# Create tensors for x and y coordinate values, scaled to range [-1, 1].
- pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height),
- math_ops.lin_space(-1., 1., num=width),
- indexing='ij')
+ pos_x, pos_y = array_ops.meshgrid(
+ math_ops.lin_space(-1., 1., num=height),
+ math_ops.lin_space(-1., 1., num=width),
+ indexing='ij')
pos_x = array_ops.reshape(pos_x, [height * width])
pos_y = array_ops.reshape(pos_y, [height * width])
-
+
if temperature is None:
temp_initializer = init_ops.ones_initializer()
else:
temp_initializer = init_ops.constant_initializer(temperature)
-
+
if not trainable:
temp_collections = None
else:
temp_collections = utils.get_variable_collections(
- variables_collections, 'temperature')
-
+ variables_collections, 'temperature')
+
temperature = variables.model_variable(
'temperature',
shape=(),
@@ -2699,14 +2772,14 @@ def spatial_softmax(features,
features = array_ops.reshape(
array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width])
- softmax_attention = nn.softmax(features/temperature)
+ softmax_attention = nn.softmax(features / temperature)
expected_x = math_ops.reduce_sum(
pos_x * softmax_attention, [1], keep_dims=True)
expected_y = math_ops.reduce_sum(
pos_y * softmax_attention, [1], keep_dims=True)
expected_xy = array_ops.concat([expected_x, expected_y], 1)
- feature_keypoints = array_ops.reshape(
- expected_xy, [-1, num_channels.value * 2])
+ feature_keypoints = array_ops.reshape(expected_xy,
+ [-1, num_channels.value * 2])
feature_keypoints.set_shape([None, num_channels.value * 2])
return feature_keypoints
@@ -2758,7 +2831,7 @@ def stack(inputs, layer, stack_args, **kwargs):
scope = 'stack'
outputs = inputs
for i in range(len(stack_args)):
- kwargs['scope'] = scope + '_' + str(i+1)
+ kwargs['scope'] = scope + '_' + str(i + 1)
layer_args = stack_args[i]
if not isinstance(layer_args, (list, tuple)):
layer_args = [layer_args]
@@ -2789,11 +2862,10 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None):
raise ValueError('The input rank must be known.')
input_rank = len(inputs.get_shape().as_list())
if dim < 0 or dim >= input_rank:
- raise ValueError(
- 'dim must be positive but smaller than the input rank.')
+ raise ValueError('dim must be positive but smaller than the input rank.')
- lengths = math_ops.sqrt(epsilon + math_ops.reduce_sum(
- math_ops.square(inputs), dim, True))
+ lengths = math_ops.sqrt(
+ epsilon + math_ops.reduce_sum(math_ops.square(inputs), dim, True))
multiples = []
if dim > 0:
multiples.append(array_ops.ones([dim], dtypes.int32))
@@ -2934,29 +3006,31 @@ def legacy_fully_connected(x,
raise ValueError('last dimension of x must be known but is None')
dtype = x.dtype.base_dtype
- weight_collections = set(list(weight_collections or []) +
- [ops.GraphKeys.GLOBAL_VARIABLES])
- w = variable_scope.get_variable('weights',
- shape=[num_input_units, num_output_units],
- dtype=dtype,
- initializer=weight_init,
- collections=weight_collections,
- regularizer=weight_regularizer,
- trainable=trainable)
- x_2_dim = x if len(dims) <= 2 else array_ops.reshape(x,
- [-1, num_input_units])
+ weight_collections = set(
+ list(weight_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES])
+ w = variable_scope.get_variable(
+ 'weights',
+ shape=[num_input_units, num_output_units],
+ dtype=dtype,
+ initializer=weight_init,
+ collections=weight_collections,
+ regularizer=weight_regularizer,
+ trainable=trainable)
+ x_2_dim = x if len(dims) <= 2 else array_ops.reshape(
+ x, [-1, num_input_units])
y = standard_ops.matmul(x_2_dim, w)
if bias_init is not None:
- bias_collections = set(list(bias_collections or []) +
- [ops.GraphKeys.GLOBAL_VARIABLES])
- b = variable_scope.get_variable('bias',
- shape=[num_output_units],
- dtype=dtype,
- initializer=bias_init,
- collections=bias_collections,
- regularizer=bias_regularizer,
- trainable=trainable)
+ bias_collections = set(
+ list(bias_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES])
+ b = variable_scope.get_variable(
+ 'bias',
+ shape=[num_output_units],
+ dtype=dtype,
+ initializer=bias_init,
+ collections=bias_collections,
+ regularizer=bias_regularizer,
+ trainable=trainable)
y = nn.bias_add(y, b)
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index a9bdbe0138..c5790c7622 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
@@ -1292,6 +1293,17 @@ class ConvolutionInPlaneTest(test.TestCase):
self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
+class DenseToSparseTest(test.TestCase):
+
+ def testDenseFromConstantToSparse(self):
+ expected_constant = np.reshape(np.arange(24, dtype=np.int64), (3, 4, 2))
+ tensor = constant_op.constant(expected_constant)
+ sparse = _layers.dense_to_sparse(tensor)
+ dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape, sparse.values)
+ with self.test_session() as sess:
+ constant = sess.run(dense)
+ self.assertAllEqual(expected_constant, constant)
+
class DropoutTest(test.TestCase):
def testCreateDropout(self):
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index ee3611ca93..3c782b54a8 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -494,7 +494,7 @@ py_test(
name = "linear_test",
size = "medium",
srcs = ["python/learn/estimators/linear_test.py"],
- shard_count = 4,
+ shard_count = 20,
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py
index a3521b4109..7240b0de14 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Dataset utilities and synthetic/reference datasets."""
from __future__ import absolute_import
@@ -46,11 +45,12 @@ DATASETS = {
# List of all synthetic datasets
SYNTHETIC = {
- # All of these will return ['data', 'target'] -> base.Dataset
- 'circles': synthetic.circles,
- 'spirals': synthetic.spirals
+ # All of these will return ['data', 'target'] -> base.Dataset
+ 'circles': synthetic.circles,
+ 'spirals': synthetic.spirals
}
+
def load_dataset(name, size='small', test_with_fake_data=False):
"""Loads dataset by name.
@@ -83,23 +83,28 @@ def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs):
seed: int or None, seed for noise
Returns:
- Shuffled features and labels for given synthetic dataset of type `base.Dataset`
+ Shuffled features and labels for given synthetic dataset of type
+ `base.Dataset`
Raises:
ValueError: Raised if `name` not found
Note:
- - This is a generic synthetic data generator - individual generators might have more parameters!
+ - This is a generic synthetic data generator - individual generators might
+ have more parameters!
See documentation for individual parameters
- - Note that the `noise` parameter uses `numpy.random.normal` and depends on `numpy`'s seed
+ - Note that the `noise` parameter uses `numpy.random.normal` and depends on
+ `numpy`'s seed
TODO:
- Support multiclass datasets
- - Need shuffling routine. Currently synthetic datasets are reshuffled to avoid train/test correlation,
+ - Need shuffling routine. Currently synthetic datasets are reshuffled to
+ avoid train/test correlation,
but that hurts reprodusability
"""
# seed = kwargs.pop('seed', None)
if name not in SYNTHETIC:
raise ValueError('Synthetic dataset not found or not implemeted: %s' % name)
else:
- return SYNTHETIC[name](n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs)
+ return SYNTHETIC[name](
+ n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs)
diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py
index 907dc0f3df..649996c49c 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Synthetic dataset generators."""
from __future__ import absolute_import
@@ -23,18 +22,27 @@ import numpy as np
from tensorflow.contrib.learn.python.learn.datasets.base import Dataset
-def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args, **kwargs):
+
+def circles(n_samples=100,
+ noise=None,
+ seed=None,
+ factor=0.8,
+ n_classes=2,
+ *args,
+ **kwargs):
"""Create circles separated by some value
Args:
n_samples: int, number of datapoints to generate
noise: float or None, standard deviation of the Gaussian noise added
seed: int or None, seed for the noise
- factor: float, size factor of the inner circles with respect to the outer ones
+ factor: float, size factor of the inner circles with respect to the outer
+ ones
n_classes: int, number of classes to generate
Returns:
- Shuffled features and labels for 'circles' synthetic dataset of type `base.Dataset`
+ Shuffled features and labels for 'circles' synthetic dataset of type
+ `base.Dataset`
Note:
The multi-class support might not work as expected if `noise` is enabled
@@ -54,7 +62,7 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args
if seed is not None:
np.random.seed(seed)
# Algo: 1) Generate initial circle, 2) For ever class generate a smaller radius circle
- linspace = np.linspace(0, 2*np.pi, n_samples // n_classes)
+ linspace = np.linspace(0, 2 * np.pi, n_samples // n_classes)
circ_x = np.empty(0, dtype=np.int32)
circ_y = np.empty(0, dtype=np.int32)
base_cos = np.cos(linspace)
@@ -66,12 +74,12 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args
circ_y = np.append(circ_y, base_sin)
base_cos *= factor
base_sin *= factor
- y = np.append(y, label*np.ones(n_samples // n_classes, dtype=np.int32))
+ y = np.append(y, label * np.ones(n_samples // n_classes, dtype=np.int32))
# Add more points if n_samples is not divisible by n_classes (unbalanced!)
extras = n_samples % n_classes
- circ_x = np.append(circ_x, np.cos(np.random.rand(extras)*2*np.pi))
- circ_y = np.append(circ_y, np.sin(np.random.rand(extras)*2*np.pi))
+ circ_x = np.append(circ_x, np.cos(np.random.rand(extras) * 2 * np.pi))
+ circ_y = np.append(circ_y, np.sin(np.random.rand(extras) * 2 * np.pi))
y = np.append(y, np.zeros(extras, dtype=np.int32))
# Reshape the features/labels
@@ -85,10 +93,13 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args
return Dataset(data=X[indices], target=y[indices])
-def spirals(n_samples=100, noise=None, seed=None,
- mode = 'archimedes',
- n_loops = 2,
- *args, **kwargs):
+def spirals(n_samples=100,
+ noise=None,
+ seed=None,
+ mode='archimedes',
+ n_loops=2,
+ *args,
+ **kwargs):
"""Create spirals
Currently only binary classification is supported for spiral generation
@@ -104,7 +115,8 @@ def spirals(n_samples=100, noise=None, seed=None,
'fermat': a spiral with branch distances decreasing (sqrt)
Returns:
- Shuffled features and labels for 'spirals' synthetic dataset of type `base.Dataset`
+ Shuffled features and labels for 'spirals' synthetic dataset of type
+ `base.Dataset`
Raises:
ValueError: If the generation `mode` is not valid
@@ -112,34 +124,35 @@ def spirals(n_samples=100, noise=None, seed=None,
TODO:
- Generation of unbalanced data
"""
- n_classes = 2 # I am not sure how to make it multiclass
+ n_classes = 2 # I am not sure how to make it multiclass
_modes = {
- 'archimedes': _archimedes_spiral,
- 'bernoulli': _bernoulli_spiral,
- 'fermat': _fermat_spiral
+ 'archimedes': _archimedes_spiral,
+ 'bernoulli': _bernoulli_spiral,
+ 'fermat': _fermat_spiral
}
if mode is None or mode not in _modes:
- raise ValueError("Cannot generate spiral with mode %s"%mode)
+ raise ValueError('Cannot generate spiral with mode %s' % mode)
if seed is not None:
np.random.seed(seed)
- linspace = np.linspace(0, 2*n_loops*np.pi, n_samples // n_classes)
+ linspace = np.linspace(0, 2 * n_loops * np.pi, n_samples // n_classes)
spir_x = np.empty(0, dtype=np.int32)
spir_y = np.empty(0, dtype=np.int32)
y = np.empty(0, dtype=np.int32)
for label in range(n_classes):
- base_cos, base_sin = _modes[mode](linspace, label*np.pi, *args, **kwargs)
+ base_cos, base_sin = _modes[mode](linspace, label * np.pi, *args, **kwargs)
spir_x = np.append(spir_x, base_cos)
spir_y = np.append(spir_y, base_sin)
- y = np.append(y, label*np.ones(n_samples // n_classes, dtype=np.int32))
+ y = np.append(y, label * np.ones(n_samples // n_classes, dtype=np.int32))
# Add more points if n_samples is not divisible by n_classes (unbalanced!)
extras = n_samples % n_classes
if extras > 0:
- x_exrta, y_extra = _modes[mode](np.random.rand(extras)*2*np.pi, *args, **kwargs)
+ x_exrta, y_extra = _modes[mode](np.random.rand(extras) * 2 * np.pi, *args,
+ **kwargs)
spir_x = np.append(spir_x, x_extra)
spir_y = np.append(spir_y, y_extra)
y = np.append(y, np.zeros(extras, dtype=np.int32))
@@ -162,7 +175,8 @@ def _archimedes_spiral(theta, theta_offset=0., *args, **kwargs):
theta: array-like, angles from polar coordinates to be converted
theta_offset: float, angle offset in radians (2*pi = 0)
"""
- x, y = theta*np.cos(theta + theta_offset), theta*np.sin(theta + theta_offset)
+ x, y = theta * np.cos(theta + theta_offset), theta * np.sin(
+ theta + theta_offset)
x_norm = np.max(np.abs(x))
y_norm = np.max(np.abs(y))
x, y = x / x_norm, y / y_norm
@@ -181,7 +195,8 @@ def _bernoulli_spiral(theta, theta_offset=0., *args, **kwargs):
"""
exp_scale = kwargs.pop('exp_scale', 0.1)
- x, y = np.exp(exp_scale*theta)*np.cos(theta + theta_offset), np.exp(exp_scale*theta)*np.sin(theta + theta_offset)
+ x, y = np.exp(exp_scale * theta) * np.cos(theta + theta_offset), np.exp(
+ exp_scale * theta) * np.sin(theta + theta_offset)
x_norm = np.max(np.abs(x))
y_norm = np.max(np.abs(y))
x, y = x / x_norm, y / y_norm
@@ -195,7 +210,8 @@ def _fermat_spiral(theta, theta_offset=0., *args, **kwargs):
theta: array-like, angles from polar coordinates to be converted
theta_offset: float, angle offset in radians (2*pi = 0)
"""
- x, y = np.sqrt(theta)*np.cos(theta + theta_offset), np.sqrt(theta)*np.sin(theta + theta_offset)
+ x, y = np.sqrt(theta) * np.cos(theta + theta_offset), np.sqrt(theta) * np.sin(
+ theta + theta_offset)
x_norm = np.max(np.abs(x))
y_norm = np.max(np.abs(y))
x, y = x / x_norm, y / y_norm
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 50c74add86..8d59fe66d9 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Base Estimator class."""
from __future__ import absolute_import
@@ -76,7 +75,6 @@ from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
-
AS_ITERABLE_DATE = '2016-09-15'
AS_ITERABLE_INSTRUCTIONS = (
'The default behavior of predict() is changing. The default value for\n'
@@ -223,8 +221,11 @@ def _get_replica_device_setter(config):
if config.num_ps_replicas > 0:
return device_setter.replica_device_setter(
- ps_tasks=config.num_ps_replicas, worker_device=worker_device,
- merge_devices=True, ps_ops=ps_ops, cluster=config.cluster_spec)
+ ps_tasks=config.num_ps_replicas,
+ worker_device=worker_device,
+ merge_devices=True,
+ ps_ops=ps_ops,
+ cluster=config.cluster_spec)
else:
return None
@@ -284,10 +285,10 @@ def _make_metrics_ops(metrics, features, labels, predictions):
raise ValueError('Invalid metric for {}. It returned a tuple with '
'len {}, expected 2.'.format(name, len(name)))
if not isinstance(predictions, dict):
- raise ValueError(
- 'Metrics passed provide (name, prediction), '
- 'but predictions are not dict. '
- 'Metrics: %s, Predictions: %s.' % (metrics, predictions))
+ raise ValueError('Metrics passed provide (name, prediction), '
+ 'but predictions are not dict. '
+ 'Metrics: %s, Predictions: %s.' % (metrics,
+ predictions))
# Here are two options: labels are single Tensor or a dict.
if isinstance(labels, dict) and name[1] in labels:
# If labels are dict and the prediction name is in it, apply metric.
@@ -298,10 +299,10 @@ def _make_metrics_ops(metrics, features, labels, predictions):
else:
# Single head metrics.
if isinstance(predictions, dict):
- raise ValueError(
- 'Metrics passed provide only name, no prediction, '
- 'but predictions are dict. '
- 'Metrics: %s, Labels: %s.' % (metrics, labels_tensor_or_dict))
+ raise ValueError('Metrics passed provide only name, no prediction, '
+ 'but predictions are dict. '
+ 'Metrics: %s, Labels: %s.' % (metrics,
+ labels_tensor_or_dict))
result[name] = metric(predictions, labels_tensor_or_dict)
return result
@@ -369,9 +370,8 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step):
logging.info(
'Summary for np.ndarray is not visible in Tensorboard by default. '
'Consider using a Tensorboard plugin for visualization (see '
- 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md ' # pylint:disable=line-too-long
- 'for more information).'
- )
+ 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
+ ' for more information).')
else:
logging.warn(
'Skipping summary for %s, must be a float, np.float32, np.int64, '
@@ -385,8 +385,8 @@ GraphRewriteSpec = collections.namedtuple('GraphRewriteSpec',
['tags', 'transforms'])
-class BaseEstimator(
- sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
+class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
+ trainable.Trainable):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
Users should not instantiate or subclass this class. Instead, use an
@@ -428,7 +428,7 @@ class BaseEstimator(
# necessary.
# pylint: disable=g-doc-exception
raise ValueError(
- "model_dir are set both in constructor and RunConfig, but with "
+ 'model_dir are set both in constructor and RunConfig, but with '
"different values. In constructor: '{}', in RunConfig: "
"'{}' ".format(model_dir, self._config.model_dir))
# pylint: enable=g-doc-exception
@@ -457,12 +457,16 @@ class BaseEstimator(
# TODO(wicke): make RunConfig immutable, and then return it without a copy.
return copy.deepcopy(self._config)
- @deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
- ('y', None), ('batch_size', None)
- )
- def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
- monitors=None, max_steps=None):
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
+ ('x', None), ('y', None), ('batch_size', None))
+ def fit(self,
+ x=None,
+ y=None,
+ input_fn=None,
+ steps=None,
+ batch_size=None,
+ monitors=None,
+ max_steps=None):
# pylint: disable=g-doc-args,g-doc-return-or-yield
"""See `Trainable`.
@@ -494,13 +498,15 @@ class BaseEstimator(
logging.info('Loss for final step: %s.', loss)
return self
- @deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
- ('y', None), ('batch_size', None)
- )
- def partial_fit(
- self, x=None, y=None, input_fn=None, steps=1, batch_size=None,
- monitors=None):
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
+ ('x', None), ('y', None), ('batch_size', None))
+ def partial_fit(self,
+ x=None,
+ y=None,
+ input_fn=None,
+ steps=1,
+ batch_size=None,
+ monitors=None):
"""Incremental fit on a batch of samples.
This method is expected to be called several times consecutively
@@ -536,13 +542,16 @@ class BaseEstimator(
"""
logging.warning('The current implementation of partial_fit is not optimized'
' for use in a loop. Consider using fit() instead.')
- return self.fit(x=x, y=y, input_fn=input_fn, steps=steps,
- batch_size=batch_size, monitors=monitors)
+ return self.fit(
+ x=x,
+ y=y,
+ input_fn=input_fn,
+ steps=steps,
+ batch_size=batch_size,
+ monitors=monitors)
- @deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
- ('y', None), ('batch_size', None)
- )
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
+ ('x', None), ('y', None), ('batch_size', None))
def evaluate(self,
x=None,
y=None,
@@ -584,13 +593,14 @@ class BaseEstimator(
eval_results.update({'global_step': global_step})
return eval_results
- @deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
- ('batch_size', None), ('as_iterable', True)
- )
- def predict(
- self, x=None, input_fn=None, batch_size=None, outputs=None,
- as_iterable=True):
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
+ ('x', None), ('batch_size', None), ('as_iterable', True))
+ def predict(self,
+ x=None,
+ input_fn=None,
+ batch_size=None,
+ outputs=None,
+ as_iterable=True):
"""Returns predictions for given features.
Args:
@@ -651,16 +661,17 @@ class BaseEstimator(
return self._model_dir
@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
- def export(self,
- export_dir,
- input_fn=export._default_input_fn, # pylint: disable=protected-access
- input_feature_key=None,
- use_deprecated_input_fn=True,
- signature_fn=None,
- prediction_key=None,
- default_batch_size=1,
- exports_to_keep=None,
- checkpoint_path=None):
+ def export(
+ self,
+ export_dir,
+ input_fn=export._default_input_fn, # pylint: disable=protected-access
+ input_feature_key=None,
+ use_deprecated_input_fn=True,
+ signature_fn=None,
+ prediction_key=None,
+ default_batch_size=1,
+ exports_to_keep=None,
+ checkpoint_path=None):
"""Exports inference graph into given dir.
Args:
@@ -798,8 +809,8 @@ class BaseEstimator(
logging.debug('Setting feature info to %s.', str(self._features_info))
if labels is not None:
if self._labels_info is not None:
- logging.debug('Given labels: %s, required signatures: %s.',
- str(labels), str(self._labels_info))
+ logging.debug('Given labels: %s, required signatures: %s.', str(labels),
+ str(self._labels_info))
if not tensor_signature.tensors_compatible(labels, self._labels_info):
raise ValueError('Labels are incompatible with given information. '
'Given labels: %s, required signatures: %s.' %
@@ -850,13 +861,13 @@ class BaseEstimator(
if not checkpoint_path:
latest_path = saver.latest_checkpoint(self._model_dir)
if not latest_path:
- raise NotFittedError("Couldn't find trained model at %s."
- % self._model_dir)
+ raise NotFittedError(
+ "Couldn't find trained model at %s." % self._model_dir)
checkpoint_path = latest_path
# Setup output directory.
- eval_dir = os.path.join(self._model_dir, 'eval' if not name else
- 'eval_' + name)
+ eval_dir = os.path.join(self._model_dir, 'eval'
+ if not name else 'eval_' + name)
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -879,8 +890,7 @@ class BaseEstimator(
'Use steps=None if intended.')
if steps:
hooks.append(
- evaluation.StopAfterNEvalsHook(
- steps, log_progress=log_progress))
+ evaluation.StopAfterNEvalsHook(steps, log_progress=log_progress))
global_step_key = 'global_step'
while global_step_key in eval_dict:
@@ -916,8 +926,8 @@ class BaseEstimator(
# Check that model has been trained.
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
- raise NotFittedError("Couldn't find trained model at %s."
- % self._model_dir)
+ raise NotFittedError(
+ "Couldn't find trained model at %s." % self._model_dir)
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -979,7 +989,8 @@ class BaseEstimator(
existing_keys = predictions.keys()
predictions = {
key: value
- for key, value in six.iteritems(predictions) if key in outputs
+ for key, value in six.iteritems(predictions)
+ if key in outputs
}
if not predictions:
raise ValueError('Expected to run at least one output from %s, '
@@ -1045,8 +1056,7 @@ class BaseEstimator(
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
- config=self._session_config
- ) as mon_sess:
+ config=self._session_config) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
@@ -1137,8 +1147,7 @@ class Estimator(BaseEstimator):
if params is not None and 'params' not in model_fn_args:
raise ValueError('Estimator\'s model_fn (%s) does not have a params '
'argument, but params (%s) were passed to the '
- 'Estimator\'s constructor.' %
- (model_fn, params))
+ 'Estimator\'s constructor.' % (model_fn, params))
if params is None and 'params' in model_fn_args:
logging.warning('Estimator\'s model_fn (%s) includes params '
'argument, but params are not passed to Estimator.',
@@ -1192,8 +1201,9 @@ class Estimator(BaseEstimator):
# Custom metrics should overwrite defaults.
if metrics:
- model_fn_ops.eval_metric_ops.update(_make_metrics_ops(
- metrics, features, labels, model_fn_ops.predictions))
+ model_fn_ops.eval_metric_ops.update(
+ _make_metrics_ops(metrics, features, labels,
+ model_fn_ops.predictions))
return model_fn_ops
@@ -1238,8 +1248,8 @@ class Estimator(BaseEstimator):
Raises:
ValueError: if `metrics` don't match `labels`.
"""
- model_fn_ops = self._call_model_fn(
- features, labels, model_fn_lib.ModeKeys.EVAL, metrics)
+ model_fn_ops = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.EVAL, metrics)
if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops:
model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = (
@@ -1263,14 +1273,16 @@ class Estimator(BaseEstimator):
self._labels_info)
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER)
- def export_savedmodel(
- self, export_dir_base, serving_input_fn,
- default_output_alternative_key=None,
- assets_extra=None,
- as_text=False,
- checkpoint_path=None,
- graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),),
- strip_default_attrs=False):
+ def export_savedmodel(self,
+ export_dir_base,
+ serving_input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None,
+ graph_rewrite_specs=(GraphRewriteSpec(
+ (tag_constants.SERVING,), ()),),
+ strip_default_attrs=False):
# pylint: disable=line-too-long
"""Exports inference graph as a SavedModel into given dir.
@@ -1297,7 +1309,8 @@ class Estimator(BaseEstimator):
default serving tag ("serve") and no rewriting.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ [Stripping Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
The string path to the exported directory.
@@ -1313,8 +1326,8 @@ class Estimator(BaseEstimator):
# Locate the latest checkpoint
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
- raise NotFittedError("Couldn't find trained model at %s."
- % self._model_dir)
+ raise NotFittedError(
+ "Couldn't find trained model at %s." % self._model_dir)
export_dir = saved_model_export_utils.get_timestamped_export_dir(
export_dir_base)
@@ -1348,10 +1361,10 @@ class Estimator(BaseEstimator):
saved_model_export_utils.get_output_alternatives(
model_fn_ops, default_output_alternative_key))
- init_op = control_flow_ops.group(
- variables.local_variables_initializer(),
- resources.initialize_resources(resources.shared_resources()),
- lookup_ops.tables_initializer())
+ init_op = control_flow_ops.group(variables.local_variables_initializer(),
+ resources.initialize_resources(
+ resources.shared_resources()),
+ lookup_ops.tables_initializer())
# Build the SignatureDefs from all pairs of input and output alternatives
signature_def_map = saved_model_export_utils.build_all_signature_defs(
@@ -1381,10 +1394,10 @@ class Estimator(BaseEstimator):
# TODO(soergel): switch to main_op or otherwise update when dust settles
builder.add_meta_graph_and_variables(
- session, untransformed_tags,
+ session,
+ untransformed_tags,
signature_def_map=signature_def_map,
- assets_collection=ops.get_collection(
- ops.GraphKeys.ASSET_FILEPATHS),
+ assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
legacy_init_op=init_op,
strip_default_attrs=strip_default_attrs)
@@ -1395,12 +1408,16 @@ class Estimator(BaseEstimator):
if graph_rewrite_specs[1:]:
# Prepare the input_names and output_names needed for the
# meta_graph_transform call below.
- input_names = [tensor.name
- for input_dict in input_alternatives.values()
- for tensor in input_dict.values()]
- output_names = [tensor.name
- for output_alternative in output_alternatives.values()
- for tensor in output_alternative[1].values()]
+ input_names = [
+ tensor.name
+ for input_dict in input_alternatives.values()
+ for tensor in input_dict.values()
+ ]
+ output_names = [
+ tensor.name
+ for output_alternative in output_alternatives.values()
+ for tensor in output_alternative[1].values()
+ ]
# Write the additional MetaGraphDefs
for graph_rewrite_spec in graph_rewrite_specs[1:]:
@@ -1419,11 +1436,11 @@ class Estimator(BaseEstimator):
# Add the extra assets
if assets_extra:
- assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
- compat.as_bytes('assets.extra'))
+ assets_extra_path = os.path.join(
+ compat.as_bytes(temp_export_dir), compat.as_bytes('assets.extra'))
for dest_relative, source in assets_extra.items():
- dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
- compat.as_bytes(dest_relative))
+ dest_absolute = os.path.join(
+ compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative))
dest_path = os.path.dirname(dest_absolute)
gfile.MakeDirs(dest_path)
gfile.Copy(source, dest_absolute)
@@ -1443,25 +1460,36 @@ class SKCompat(sklearn.BaseEstimator):
def fit(self, x, y, batch_size=128, steps=None, max_steps=None,
monitors=None):
- input_fn, feed_fn = _get_input_fn(x, y, input_fn=None, feed_fn=None,
- batch_size=batch_size, shuffle=True,
- epochs=None)
+ input_fn, feed_fn = _get_input_fn(
+ x,
+ y,
+ input_fn=None,
+ feed_fn=None,
+ batch_size=batch_size,
+ shuffle=True,
+ epochs=None)
all_monitors = []
if feed_fn:
all_monitors = [basic_session_run_hooks.FeedFnHook(feed_fn)]
if monitors:
all_monitors.extend(monitors)
- self._estimator.fit(input_fn=input_fn,
- steps=steps,
- max_steps=max_steps,
- monitors=all_monitors)
+ self._estimator.fit(
+ input_fn=input_fn,
+ steps=steps,
+ max_steps=max_steps,
+ monitors=all_monitors)
return self
def score(self, x, y, batch_size=128, steps=None, metrics=None, name=None):
- input_fn, feed_fn = _get_input_fn(x, y, input_fn=None,
- feed_fn=None, batch_size=batch_size,
- shuffle=False, epochs=1)
+ input_fn, feed_fn = _get_input_fn(
+ x,
+ y,
+ input_fn=None,
+ feed_fn=None,
+ batch_size=batch_size,
+ shuffle=False,
+ epochs=1)
if metrics is not None and not isinstance(metrics, dict):
raise ValueError('Metrics argument should be None or dict. '
'Got %s.' % metrics)
@@ -1477,8 +1505,13 @@ class SKCompat(sklearn.BaseEstimator):
def predict(self, x, batch_size=128, outputs=None):
input_fn, feed_fn = _get_input_fn(
- x, None, input_fn=None, feed_fn=None, batch_size=batch_size,
- shuffle=False, epochs=1)
+ x,
+ None,
+ input_fn=None,
+ feed_fn=None,
+ batch_size=batch_size,
+ shuffle=False,
+ epochs=1)
results = list(
self._estimator._infer_model(
input_fn=input_fn,
@@ -1489,7 +1522,6 @@ class SKCompat(sklearn.BaseEstimator):
if not isinstance(results[0], dict):
return np.concatenate([output for output in results], axis=0)
return {
- key: np.concatenate(
- [output[key] for output in results], axis=0)
+ key: np.concatenate([output[key] for output in results], axis=0)
for key in results[0]
}
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 5f682838b7..d81a534b79 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -111,8 +111,8 @@ def boston_eval_fn():
constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM])
labels = array_ops.reshape(
constant_op.constant(boston.target), [n_examples, 1])
- return array_ops.concat([features, features], 0), array_ops.concat(
- [labels, labels], 0)
+ return array_ops.concat([features, features],
+ 0), array_ops.concat([labels, labels], 0)
def extract(data, key):
@@ -147,7 +147,10 @@ def linear_model_fn(features, labels, mode):
(_, features), = features.items()
prediction, loss = (models.linear_regression_zero_init(features, labels))
train_op = optimizers.optimize_loss(
- loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1)
+ loss,
+ training_util.get_global_step(),
+ optimizer='Adagrad',
+ learning_rate=0.1)
return prediction, loss, train_op
@@ -157,7 +160,10 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode):
model_fn.ModeKeys.INFER)
prediction, loss = (models.linear_regression_zero_init(features, labels))
train_op = optimizers.optimize_loss(
- loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1)
+ loss,
+ training_util.get_global_step(),
+ optimizer='Adagrad',
+ learning_rate=0.1)
return model_fn.ModelFnOps(
mode=mode, predictions=prediction, loss=loss, train_op=train_op)
@@ -168,7 +174,10 @@ def logistic_model_no_mode_fn(features, labels):
labels = array_ops.one_hot(labels, 3, 1, 0)
prediction, loss = (models.logistic_regression_zero_init(features, labels))
train_op = optimizers.optimize_loss(
- loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1)
+ loss,
+ training_util.get_global_step(),
+ optimizer='Adagrad',
+ learning_rate=0.1)
return {
'class': math_ops.argmax(prediction, 1),
'prob': prediction
@@ -184,14 +193,12 @@ def _build_estimator_for_export_tests(tmpdir):
def _input_fn():
iris = base.load_iris()
return {
- 'feature': constant_op.constant(
- iris.data, dtype=dtypes.float32)
+ 'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
}, constant_op.constant(
iris.target, shape=[150], dtype=dtypes.int32)
feature_columns = [
- feature_column_lib.real_valued_column(
- 'feature', dimension=4)
+ feature_column_lib.real_valued_column('feature', dimension=4)
]
est = linear.LinearRegressor(feature_columns)
@@ -291,8 +298,8 @@ class CheckCallsMonitor(monitors_lib.BaseMonitor):
self.begin_calls == self.expect_calls)
-def _model_fn_ops(
- expected_features, expected_labels, actual_features, actual_labels, mode):
+def _model_fn_ops(expected_features, expected_labels, actual_features,
+ actual_labels, mode):
assert_ops = tuple([
check_ops.assert_equal(
expected_features[k], actual_features[k], name='assert_%s' % k)
@@ -310,11 +317,11 @@ def _model_fn_ops(
def _make_input_fn(features, labels):
+
def _input_fn():
- return {
- k: constant_op.constant(v)
- for k, v in six.iteritems(features)
- }, constant_op.constant(labels)
+ return {k: constant_op.constant(v)
+ for k, v in six.iteritems(features)}, constant_op.constant(labels)
+
return _input_fn
@@ -369,11 +376,13 @@ class EstimatorModelFnTest(test.TestCase):
self.assertEqual(expected_params, params)
self.assertTrue(config.i_am_test)
return _model_fn_ops(features, labels, arg0, arg1, mode)
+
partial_model_fn = functools.partial(
_model_fn, foo=expected_foo, bar=expected_bar)
est = estimator.Estimator(
- model_fn=partial_model_fn, params=expected_params,
+ model_fn=partial_model_fn,
+ params=expected_params,
config=expected_config)
self.assertEqual(0, model_fn_call_count[0])
est.fit(input_fn=_make_input_fn(features, labels), steps=1)
@@ -382,7 +391,12 @@ class EstimatorModelFnTest(test.TestCase):
def testModelFnWithModelDir(self):
expected_param = {'some_param': 'some_value'}
expected_model_dir = tempfile.mkdtemp()
- def _argument_checker(features, labels, mode, params, config=None,
+
+ def _argument_checker(features,
+ labels,
+ mode,
+ params,
+ config=None,
model_dir=None):
_, _, _ = features, labels, config
self.assertEqual(model_fn.ModeKeys.TRAIN, mode)
@@ -390,9 +404,11 @@ class EstimatorModelFnTest(test.TestCase):
self.assertEqual(model_dir, expected_model_dir)
return (constant_op.constant(0.), constant_op.constant(0.),
training_util.get_global_step().assign_add(1))
- est = estimator.Estimator(model_fn=_argument_checker,
- params=expected_param,
- model_dir=expected_model_dir)
+
+ est = estimator.Estimator(
+ model_fn=_argument_checker,
+ params=expected_param,
+ model_dir=expected_model_dir)
est.fit(input_fn=boston_input_fn, steps=1)
def testInvalidModelFn_no_train_op(self):
@@ -447,8 +463,7 @@ class EstimatorModelFnTest(test.TestCase):
est.predict(input_fn=boston_input_fn)
with self.assertRaisesRegexp(ValueError, 'Missing prediction'):
est.predict(
- input_fn=functools.partial(
- boston_input_fn, num_epochs=1),
+ input_fn=functools.partial(boston_input_fn, num_epochs=1),
as_iterable=True)
def testModelFnScaffoldInTraining(self):
@@ -498,15 +513,17 @@ class EstimatorModelFnTest(test.TestCase):
self.assertTrue(self.mock_saver.restore.called)
est.predict(input_fn=input_fn)
self.assertTrue(self.mock_saver.restore.called)
+
def serving_input_fn():
- serialized_tf_example = array_ops.placeholder(dtype=dtypes.string,
- shape=[None],
- name='input_example_tensor')
+ serialized_tf_example = array_ops.placeholder(
+ dtype=dtypes.string, shape=[None], name='input_example_tensor')
features, labels = input_fn()
- return input_fn_utils.InputFnOps(
- features, labels, {'examples': serialized_tf_example})
+ return input_fn_utils.InputFnOps(features, labels, {
+ 'examples': serialized_tf_example
+ })
- est.export_savedmodel(os.path.join(est.model_dir, 'export'), serving_input_fn)
+ est.export_savedmodel(
+ os.path.join(est.model_dir, 'export'), serving_input_fn)
self.assertTrue(self.mock_saver.restore.called)
@@ -550,33 +567,28 @@ class EstimatorTest(test.TestCase):
def testRunConfigModelDir(self):
config = run_config.RunConfig(model_dir='test_dir')
- est = estimator.Estimator(model_fn=linear_model_fn,
- config=config)
+ est = estimator.Estimator(model_fn=linear_model_fn, config=config)
self.assertEqual('test_dir', est.config.model_dir)
self.assertEqual('test_dir', est.model_dir)
def testModelDirAndRunConfigModelDir(self):
config = run_config.RunConfig(model_dir='test_dir')
- est = estimator.Estimator(model_fn=linear_model_fn,
- config=config,
- model_dir='test_dir')
+ est = estimator.Estimator(
+ model_fn=linear_model_fn, config=config, model_dir='test_dir')
self.assertEqual('test_dir', est.config.model_dir)
with self.assertRaisesRegexp(
- ValueError,
- 'model_dir are set both in constructor and RunConfig, '
+ ValueError, 'model_dir are set both in constructor and RunConfig, '
'but with different'):
- estimator.Estimator(model_fn=linear_model_fn,
- config=config,
- model_dir='different_dir')
+ estimator.Estimator(
+ model_fn=linear_model_fn, config=config, model_dir='different_dir')
def testModelDirIsCopiedToRunConfig(self):
config = run_config.RunConfig()
self.assertIsNone(config.model_dir)
- est = estimator.Estimator(model_fn=linear_model_fn,
- model_dir='test_dir',
- config=config)
+ est = estimator.Estimator(
+ model_fn=linear_model_fn, model_dir='test_dir', config=config)
self.assertEqual('test_dir', est.config.model_dir)
self.assertEqual('test_dir', est.model_dir)
@@ -656,25 +668,27 @@ class EstimatorTest(test.TestCase):
boston = base.load_boston()
output_dir = tempfile.mkdtemp()
est = estimator.SKCompat(
- estimator.Estimator(
- model_fn=linear_model_fn, model_dir=output_dir))
+ estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir))
float64_labels = boston.target.astype(np.float64)
est.fit(x=boston.data, y=float64_labels, steps=50)
scores = est.score(
x=boston.data,
y=float64_labels,
- metrics={'MSE': metric_ops.streaming_mean_squared_error})
+ metrics={
+ 'MSE': metric_ops.streaming_mean_squared_error
+ })
del est
# Create another estimator object with the same output dir.
est2 = estimator.SKCompat(
- estimator.Estimator(
- model_fn=linear_model_fn, model_dir=output_dir))
+ estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir))
# Check we can evaluate and predict.
scores2 = est2.score(
x=boston.data,
y=float64_labels,
- metrics={'MSE': metric_ops.streaming_mean_squared_error})
+ metrics={
+ 'MSE': metric_ops.streaming_mean_squared_error
+ })
self.assertAllClose(scores['MSE'], scores2['MSE'])
predictions = np.array(list(est2.predict(x=boston.data)))
other_score = _sklearn.mean_squared_error(predictions, float64_labels)
@@ -685,14 +699,15 @@ class EstimatorTest(test.TestCase):
scores3 = est2.score(
x=boston.data,
y=float64_labels,
- metrics={'MSE': metric_ops.streaming_mean_squared_error})
+ metrics={
+ 'MSE': metric_ops.streaming_mean_squared_error
+ })
self.assertLess(scores3['MSE'], scores['MSE'])
def test_checkpoint_contains_relative_paths(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
- model_dir=tmpdir,
- model_fn=linear_model_fn_with_model_fn_ops)
+ model_dir=tmpdir, model_fn=linear_model_fn_with_model_fn_ops)
est.fit(input_fn=boston_input_fn, steps=5)
checkpoint_file_content = file_io.read_file_to_string(
@@ -700,22 +715,20 @@ class EstimatorTest(test.TestCase):
ckpt = checkpoint_state_pb2.CheckpointState()
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
- self.assertAllEqual(
- ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+ self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'],
+ ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
tmpdir = tempfile.mkdtemp()
model_dir1 = os.path.join(tmpdir, 'model_dir1')
est1 = estimator.Estimator(
- model_dir=model_dir1,
- model_fn=linear_model_fn_with_model_fn_ops)
+ model_dir=model_dir1, model_fn=linear_model_fn_with_model_fn_ops)
est1.fit(input_fn=boston_input_fn, steps=5)
model_dir2 = os.path.join(tmpdir, 'model_dir2')
os.renames(model_dir1, model_dir2)
est2 = estimator.Estimator(
- model_dir=model_dir2,
- model_fn=linear_model_fn_with_model_fn_ops)
+ model_dir=model_dir2, model_fn=linear_model_fn_with_model_fn_ops)
self.assertEqual(5, est2.get_variable_value('global_step'))
est2.fit(input_fn=boston_input_fn, steps=5)
self.assertEqual(10, est2.get_variable_value('global_step'))
@@ -724,7 +737,9 @@ class EstimatorTest(test.TestCase):
boston = base.load_boston()
est = estimator.SKCompat(
estimator.Estimator(
- model_fn=linear_model_params_fn, params={'learning_rate': 0.01}))
+ model_fn=linear_model_params_fn, params={
+ 'learning_rate': 0.01
+ }))
est.fit(x=boston.data, y=boston.target, steps=100)
def testHooksNotChanged(self):
@@ -824,11 +839,13 @@ class EstimatorTest(test.TestCase):
def testMonitorsForFit(self):
est = estimator.Estimator(model_fn=linear_model_fn)
- est.fit(input_fn=boston_input_fn,
- steps=21,
- monitors=[CheckCallsMonitor(expect_calls=21)])
+ est.fit(
+ input_fn=boston_input_fn,
+ steps=21,
+ monitors=[CheckCallsMonitor(expect_calls=21)])
def testHooksForEvaluate(self):
+
class CheckCallHook(session_run_hook.SessionRunHook):
def __init__(self):
@@ -874,7 +891,9 @@ class EstimatorTest(test.TestCase):
est.evaluate(
input_fn=boston_input_fn,
steps=200,
- metrics={'MSE': _streaming_mean_squared_error_histogram})
+ metrics={
+ 'MSE': _streaming_mean_squared_error_histogram
+ })
events = util_test.latest_events(est.model_dir + '/eval')
output_values = {}
for e in events:
@@ -903,7 +922,9 @@ class EstimatorTest(test.TestCase):
est.evaluate(
input_fn=boston_input_fn,
steps=200,
- metrics={'PMT': _streaming_precition_mean_tensor})
+ metrics={
+ 'PMT': _streaming_precition_mean_tensor
+ })
events = util_test.latest_events(est.model_dir + '/eval')
output_values = {}
for e in events:
@@ -956,8 +977,8 @@ class EstimatorTest(test.TestCase):
self.assertTrue(
gfile.Exists(
os.path.join(
- compat.as_bytes(export_dir), compat.as_bytes(
- 'saved_model.pb'))))
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
self.assertTrue(
gfile.Exists(
os.path.join(
@@ -1017,11 +1038,11 @@ class EstimatorTest(test.TestCase):
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
self.assertTrue('linear/linear/feature/matmul' in graph_ops)
- self.assertItemsEqual(
- ['bogus_lookup', 'feature'],
- [compat.as_str_any(x) for x in graph.get_collection(
- constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)])
-
+ self.assertItemsEqual(['bogus_lookup', 'feature'], [
+ compat.as_str_any(x)
+ for x in graph.get_collection(
+ constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)
+ ])
# cleanup
gfile.DeleteRecursively(tmpdir)
@@ -1039,8 +1060,8 @@ class EstimatorTest(test.TestCase):
self.assertTrue(
gfile.Exists(
os.path.join(
- compat.as_bytes(export_dir), compat.as_bytes(
- 'saved_model.pb'))))
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
self.assertTrue(
gfile.Exists(
os.path.join(
@@ -1083,19 +1104,22 @@ class EstimatorTest(test.TestCase):
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
export_dir = est.export_savedmodel(
- export_dir_base, serving_input_fn, assets_extra=assets_extra,
+ export_dir_base,
+ serving_input_fn,
+ assets_extra=assets_extra,
graph_rewrite_specs=[
estimator.GraphRewriteSpec(['tag_1'], []),
estimator.GraphRewriteSpec(['tag_2', 'tag_3'],
- ['strip_unused_nodes'])])
+ ['strip_unused_nodes'])
+ ])
self.assertTrue(gfile.Exists(export_dir_base))
self.assertTrue(gfile.Exists(export_dir))
self.assertTrue(
gfile.Exists(
os.path.join(
- compat.as_bytes(export_dir), compat.as_bytes(
- 'saved_model.pb'))))
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
self.assertTrue(
gfile.Exists(
os.path.join(
@@ -1208,18 +1232,15 @@ class InferRealValuedColumnsTest(test.TestCase):
self.assertEqual(1, len(feature_columns))
feature_column = feature_columns[0]
self.assertEqual('', feature_column.name)
- self.assertEqual(
- {
- '':
- parsing_ops.FixedLenFeature(
- shape=expected_shape, dtype=expected_dtype)
- },
- feature_column.config)
+ self.assertEqual({
+ '':
+ parsing_ops.FixedLenFeature(
+ shape=expected_shape, dtype=expected_dtype)
+ }, feature_column.config)
def testInt32Input(self):
feature_columns = estimator.infer_real_valued_columns_from_input(
- np.ones(
- shape=[7, 8], dtype=np.int32))
+ np.ones(shape=[7, 8], dtype=np.int32))
self._assert_single_feature_column([8], dtypes.int32, feature_columns)
def testInt32InputFn(self):
@@ -1229,8 +1250,7 @@ class InferRealValuedColumnsTest(test.TestCase):
def testInt64Input(self):
feature_columns = estimator.infer_real_valued_columns_from_input(
- np.ones(
- shape=[7, 8], dtype=np.int64))
+ np.ones(shape=[7, 8], dtype=np.int64))
self._assert_single_feature_column([8], dtypes.int64, feature_columns)
def testInt64InputFn(self):
@@ -1240,8 +1260,7 @@ class InferRealValuedColumnsTest(test.TestCase):
def testFloat32Input(self):
feature_columns = estimator.infer_real_valued_columns_from_input(
- np.ones(
- shape=[7, 8], dtype=np.float32))
+ np.ones(shape=[7, 8], dtype=np.float32))
self._assert_single_feature_column([8], dtypes.float32, feature_columns)
def testFloat32InputFn(self):
@@ -1251,8 +1270,7 @@ class InferRealValuedColumnsTest(test.TestCase):
def testFloat64Input(self):
feature_columns = estimator.infer_real_valued_columns_from_input(
- np.ones(
- shape=[7, 8], dtype=np.float64))
+ np.ones(shape=[7, 8], dtype=np.float64))
self._assert_single_feature_column([8], dtypes.float64, feature_columns)
def testFloat64InputFn(self):
@@ -1271,8 +1289,8 @@ class InferRealValuedColumnsTest(test.TestCase):
ValueError, 'on integer or non floating types are not supported'):
# pylint: disable=g-long-lambda
estimator.infer_real_valued_columns_from_input_fn(
- lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool),
- None))
+ lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool), None)
+ )
def testStringInput(self):
with self.assertRaisesRegexp(
@@ -1309,8 +1327,9 @@ class ReplicaDeviceSetterTest(test.TestCase):
def testVariablesAreOnPs(self):
tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}}
- with test.mock.patch.dict('os.environ',
- {'TF_CONFIG': json.dumps(tf_config)}):
+ with test.mock.patch.dict('os.environ', {
+ 'TF_CONFIG': json.dumps(tf_config)
+ }):
config = run_config.RunConfig()
with ops.device(estimator._get_replica_device_setter(config)):
@@ -1337,14 +1356,14 @@ class ReplicaDeviceSetterTest(test.TestCase):
def testMutableHashTableIsOnPs(self):
tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}}
- with test.mock.patch.dict('os.environ',
- {'TF_CONFIG': json.dumps(tf_config)}):
+ with test.mock.patch.dict('os.environ', {
+ 'TF_CONFIG': json.dumps(tf_config)
+ }):
config = run_config.RunConfig()
with ops.device(estimator._get_replica_device_setter(config)):
default_val = constant_op.constant([-1, -1], dtypes.int64)
- table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
- default_val)
+ table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
input_string = constant_op.constant(['brain', 'salad', 'tank'])
output = table.lookup(input_string)
self.assertDeviceEqual('/job:ps/task:0', table._table_ref.device)
@@ -1354,8 +1373,7 @@ class ReplicaDeviceSetterTest(test.TestCase):
with ops.device(
estimator._get_replica_device_setter(run_config.RunConfig())):
default_val = constant_op.constant([-1, -1], dtypes.int64)
- table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
- default_val)
+ table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
input_string = constant_op.constant(['brain', 'salad', 'tank'])
output = table.lookup(input_string)
self.assertDeviceEqual('', table._table_ref.device)
@@ -1371,8 +1389,9 @@ class ReplicaDeviceSetterTest(test.TestCase):
'index': 3
}
}
- with test.mock.patch.dict('os.environ',
- {'TF_CONFIG': json.dumps(tf_config)}):
+ with test.mock.patch.dict('os.environ', {
+ 'TF_CONFIG': json.dumps(tf_config)
+ }):
config = run_config.RunConfig()
with ops.device(estimator._get_replica_device_setter(config)):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
index 8131e0fde6..2113fae394 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
@@ -72,9 +72,11 @@ class FeatureEngineeringFunctionTest(test.TestCase):
# predictions = transformed_x (9)
self.assertEqual(9., prediction)
metrics = estimator.evaluate(
- input_fn=input_fn, steps=1,
- metrics={"label":
- metric_spec.MetricSpec(lambda predictions, labels: labels)})
+ input_fn=input_fn,
+ steps=1,
+ metrics={
+ "label": metric_spec.MetricSpec(lambda predictions, labels: labels)
+ })
# labels = transformed_y (99)
self.assertEqual(99., metrics["label"])
@@ -82,10 +84,10 @@ class FeatureEngineeringFunctionTest(test.TestCase):
def input_fn():
return {
- "x": constant_op.constant(["9."])
- }, {
- "y": constant_op.constant(["99."])
- }
+ "x": constant_op.constant(["9."])
+ }, {
+ "y": constant_op.constant(["99."])
+ }
def feature_engineering_fn(features, labels):
# Github #12205: raise a TypeError if called twice.
@@ -104,15 +106,17 @@ class FeatureEngineeringFunctionTest(test.TestCase):
return predictions, loss, update_global_step
estimator = estimator_lib.Estimator(
- model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
+ model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
estimator.fit(input_fn=input_fn, steps=1)
prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True))
# predictions = transformed_x (9)
self.assertEqual(9., prediction)
metrics = estimator.evaluate(
- input_fn=input_fn, steps=1,
- metrics={"label":
- metric_spec.MetricSpec(lambda predictions, labels: labels)})
+ input_fn=input_fn,
+ steps=1,
+ metrics={
+ "label": metric_spec.MetricSpec(lambda predictions, labels: labels)
+ })
# labels = transformed_y (99)
self.assertEqual(99., metrics["label"])
@@ -150,12 +154,10 @@ class FeatureEngineeringFunctionTest(test.TestCase):
# predictions = x
prediction_with_fe_fn = next(
- estimator_with_fe_fn.predict(
- input_fn=input_fn, as_iterable=True))
+ estimator_with_fe_fn.predict(input_fn=input_fn, as_iterable=True))
self.assertEqual(9., prediction_with_fe_fn)
prediction_without_fe_fn = next(
- estimator_without_fe_fn.predict(
- input_fn=input_fn, as_iterable=True))
+ estimator_without_fe_fn.predict(input_fn=input_fn, as_iterable=True))
self.assertEqual(1., prediction_without_fe_fn)
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 3e0b1ad21a..0948dee7e2 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Monitors instrument the training process.
@@get_default_monitors
@@ -151,8 +150,8 @@ class BaseMonitor(object):
ValueError: if we've not begun an epoch, or `epoch` number does not match.
"""
if self._current_epoch != epoch:
- raise ValueError(
- "epoch_end expected %s but got %s.", self._current_epoch, epoch)
+ raise ValueError("epoch_end expected %s but got %s.", self._current_epoch,
+ epoch)
self._current_epoch = None
def step_begin(self, step):
@@ -171,8 +170,8 @@ class BaseMonitor(object):
ValueError: if we've already begun a step, or `step` < 0, or
`step` > `max_steps`.
"""
- if (step < 0) or (
- (self._max_steps is not None) and (step > self._max_steps)):
+ if (step < 0) or ((self._max_steps is not None) and
+ (step > self._max_steps)):
raise ValueError("Invalid step %s." % step)
self._current_step = step
return []
@@ -203,8 +202,8 @@ class BaseMonitor(object):
ValueError: if we've not begun a step, or `step` number does not match.
"""
if self._current_step != step:
- raise ValueError(
- "step_end expected %s but got %s.", self._current_step, step)
+ raise ValueError("step_end expected %s but got %s.", self._current_step,
+ step)
self._current_step = None
return False
@@ -253,6 +252,7 @@ class EveryN(BaseMonitor):
treatment.
"""
+
# TODO(ipolosukhin): Add also every n seconds.
def __init__(self, every_n_steps=100, first_n_steps=1):
@@ -475,8 +475,8 @@ class LoggingTrainable(EveryN):
super(LoggingTrainable, self).every_n_step_begin(step)
# Get a list of trainable variables at the beginning of every N steps.
# We cannot get this in __init__ because train_op has not been generated.
- trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=self._scope)
+ trainables = ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES, scope=self._scope)
self._names = {}
for var in trainables:
self._names[var.name] = var.value().name
@@ -561,12 +561,19 @@ class ValidationMonitor(EveryN):
provided.
"""
- def __init__(self, x=None, y=None, input_fn=None, batch_size=None,
+ def __init__(self,
+ x=None,
+ y=None,
+ input_fn=None,
+ batch_size=None,
eval_steps=None,
- every_n_steps=100, metrics=None, hooks=None,
+ every_n_steps=100,
+ metrics=None,
+ hooks=None,
early_stopping_rounds=None,
early_stopping_metric="loss",
- early_stopping_metric_minimize=True, name=None):
+ early_stopping_metric_minimize=True,
+ name=None):
"""Initializes a ValidationMonitor.
Args:
@@ -597,8 +604,8 @@ class ValidationMonitor(EveryN):
Raises:
ValueError: If both x and input_fn are provided.
"""
- super(ValidationMonitor, self).__init__(every_n_steps=every_n_steps,
- first_n_steps=-1)
+ super(ValidationMonitor, self).__init__(
+ every_n_steps=every_n_steps, first_n_steps=-1)
# TODO(mdan): Checks like this are already done by evaluate.
if x is None and input_fn is None:
raise ValueError("Either x or input_fn should be provided.")
@@ -654,20 +661,27 @@ class ValidationMonitor(EveryN):
def _evaluate_estimator(self):
if isinstance(self._estimator, core_estimator.Estimator):
- if any((x is not None for x in
- [self.x, self.y, self.batch_size, self.metrics])):
+ if any((x is not None
+ for x in [self.x, self.y, self.batch_size, self.metrics])):
raise ValueError(
"tf.estimator.Estimator does not support following "
"arguments: x, y, batch_size, metrics. Should set as `None` "
"in ValidationMonitor")
return self._estimator.evaluate(
- input_fn=self.input_fn, steps=self.eval_steps, hooks=self.hooks,
+ input_fn=self.input_fn,
+ steps=self.eval_steps,
+ hooks=self.hooks,
name=self.name)
else:
return self._estimator.evaluate(
- x=self.x, y=self.y, input_fn=self.input_fn,
- batch_size=self.batch_size, steps=self.eval_steps,
- metrics=self.metrics, hooks=self.hooks, name=self.name)
+ x=self.x,
+ y=self.y,
+ input_fn=self.input_fn,
+ batch_size=self.batch_size,
+ steps=self.eval_steps,
+ metrics=self.metrics,
+ hooks=self.hooks,
+ name=self.name)
def every_n_step_end(self, step, outputs):
super(ValidationMonitor, self).every_n_step_end(step, outputs)
@@ -700,8 +714,9 @@ class ValidationMonitor(EveryN):
# Early stopping logic.
if self.early_stopping_rounds is not None:
if self.early_stopping_metric not in validation_outputs:
- raise ValueError("Metric %s missing from outputs %s." % (
- self.early_stopping_metric, set(validation_outputs.keys())))
+ raise ValueError("Metric %s missing from outputs %s." %
+ (self.early_stopping_metric,
+ set(validation_outputs.keys())))
current_value = validation_outputs[self.early_stopping_metric]
if (self._best_value is None or (self.early_stopping_metric_minimize and
(current_value < self._best_value)) or
@@ -712,9 +727,9 @@ class ValidationMonitor(EveryN):
self._best_value_step = step
stop_now = (step - self._best_value_step >= self.early_stopping_rounds)
if stop_now:
- logging.info("Stopping. Best step: {} with {} = {}."
- .format(self._best_value_step,
- self.early_stopping_metric, self._best_value))
+ logging.info("Stopping. Best step: {} with {} = {}.".format(
+ self._best_value_step, self.early_stopping_metric,
+ self._best_value))
self._early_stopped = True
return True
return False
@@ -763,8 +778,11 @@ class CaptureVariable(EveryN):
self._var_values[step] = _extract_output(outputs, self._var_name)
-def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100,
- output_dir=None, summary_writer=None):
+def get_default_monitors(loss_op=None,
+ summary_op=None,
+ save_summary_steps=100,
+ output_dir=None,
+ summary_writer=None):
"""Returns a default set of typically-used monitors.
Args:
@@ -782,9 +800,12 @@ def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100,
if loss_op is not None:
monitors.append(PrintTensor(tensor_names={"loss": loss_op.name}))
if summary_op is not None:
- monitors.append(SummarySaver(summary_op, save_steps=save_summary_steps,
- output_dir=output_dir,
- summary_writer=summary_writer))
+ monitors.append(
+ SummarySaver(
+ summary_op,
+ save_steps=save_summary_steps,
+ output_dir=output_dir,
+ summary_writer=summary_writer))
return monitors
@@ -794,8 +815,10 @@ class GraphDump(BaseMonitor):
Note, this is very expensive, prefer `PrintTensor` in production.
"""
- IGNORE_OPS = ["Const", "Assign", "Identity", "Placeholder",
- "RandomUniform", "Cast", "RestoreSlice"]
+ IGNORE_OPS = [
+ "Const", "Assign", "Identity", "Placeholder", "RandomUniform", "Cast",
+ "RestoreSlice"
+ ]
def __init__(self, ignore_ops=None):
"""Initializes GraphDump monitor.
@@ -881,8 +904,8 @@ class ExportMonitor(EveryN):
"""Monitor that exports Estimator every N steps."""
@deprecation.deprecated("2017-03-25",
- "ExportMonitor is deprecated. Please pass an "
- "ExportStrategy to Experiment instead.")
+ "ExportMonitor is deprecated. Please pass an "
+ "ExportStrategy to Experiment instead.")
def __init__(self,
every_n_steps,
export_dir,
@@ -1088,8 +1111,7 @@ class CheckpointSaver(BaseMonitor):
class StepCounter(EveryN):
"""Steps per second monitor."""
- def __init__(self, every_n_steps=100, output_dir=None,
- summary_writer=None):
+ def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None):
super(StepCounter, self).__init__(every_n_steps=every_n_steps)
self._summary_tag = "global_step/sec"
self._last_reported_step = None
@@ -1101,7 +1123,8 @@ class StepCounter(EveryN):
def set_estimator(self, estimator):
super(StepCounter, self).set_estimator(estimator)
if self._summary_writer is None:
- self._summary_writer = core_summary.FileWriterCache.get(estimator.model_dir)
+ self._summary_writer = core_summary.FileWriterCache.get(
+ estimator.model_dir)
def every_n_step_end(self, current_step, outputs):
current_time = time.time()
@@ -1109,8 +1132,9 @@ class StepCounter(EveryN):
added_steps = current_step - self._last_reported_step
elapsed_time = current_time - self._last_reported_time
steps_per_sec = added_steps / elapsed_time
- summary = Summary(value=[Summary.Value(tag=self._summary_tag,
- simple_value=steps_per_sec)])
+ summary = Summary(value=[
+ Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
+ ])
self._summary_writer.add_summary(summary, current_step)
self._last_reported_step = current_step
self._last_reported_time = current_time
diff --git a/tensorflow/contrib/learn/python/learn/utils/export_test.py b/tensorflow/contrib/learn/python/learn/utils/export_test.py
index 95070ada3b..9bfb1fc952 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export_test.py
@@ -50,6 +50,7 @@ def _training_input_fn():
class ExportTest(test.TestCase):
+
def _get_default_signature(self, export_meta_filename):
""" Gets the default signature from the export.meta file. """
with session.Session():
@@ -69,18 +70,18 @@ class ExportTest(test.TestCase):
# Only the written checkpoints are exported.
self.assertTrue(
saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')),
- 'Exported checkpoint expected but not found: %s' %
- os.path.join(export_dir, '00000001', 'export'))
+ 'Exported checkpoint expected but not found: %s' % os.path.join(
+ export_dir, '00000001', 'export'))
self.assertTrue(
saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')),
- 'Exported checkpoint expected but not found: %s' %
- os.path.join(export_dir, '00000010', 'export'))
+ 'Exported checkpoint expected but not found: %s' % os.path.join(
+ export_dir, '00000010', 'export'))
self.assertEquals(
six.b(os.path.join(export_dir, '00000010')),
export_monitor.last_export_dir)
# Validate the signature
signature = self._get_default_signature(
- os.path.join(export_dir, '00000010', 'export.meta'))
+ os.path.join(export_dir, '00000010', 'export.meta'))
self.assertTrue(signature.HasField(expected_signature))
def testExportMonitor_EstimatorProvidesSignature(self):
@@ -116,8 +117,7 @@ class ExportTest(test.TestCase):
def _serving_input_fn():
return {
_X_KEY:
- random_ops.random_uniform(
- shape=(1,), minval=0.0, maxval=1000.0)
+ random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
}, None
input_feature_key = 'my_example_key'
@@ -160,8 +160,7 @@ class ExportTest(test.TestCase):
input_feature_key:
None,
_X_KEY:
- random_ops.random_uniform(
- shape=(1,), minval=0.0, maxval=1000.0)
+ random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
}, None
monitor = learn.monitors.ExportMonitor(
@@ -182,8 +181,7 @@ class ExportTest(test.TestCase):
def _serving_input_fn():
return {
input_feature_key:
- array_ops.placeholder(
- dtype=dtypes.string, shape=(1,))
+ array_ops.placeholder(dtype=dtypes.string, shape=(1,))
}, None
monitor = learn.monitors.ExportMonitor(
@@ -204,11 +202,9 @@ class ExportTest(test.TestCase):
def _serving_input_fn():
return {
input_feature_key:
- array_ops.placeholder(
- dtype=dtypes.string, shape=(1,)),
+ array_ops.placeholder(dtype=dtypes.string, shape=(1,)),
_X_KEY:
- random_ops.random_uniform(
- shape=(1,), minval=0.0, maxval=1000.0)
+ random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
}, None
export_dir = os.path.join(tempfile.mkdtemp(), 'export')
@@ -227,8 +223,8 @@ class ExportTest(test.TestCase):
def _regression_signature(examples, unused_features, predictions):
signatures = {}
- signatures['regression'] = (exporter.regression_signature(examples,
- predictions))
+ signatures['regression'] = (
+ exporter.regression_signature(examples, predictions))
return signatures['regression'], signatures
random.seed(42)
@@ -248,10 +244,10 @@ class ExportTest(test.TestCase):
with self.assertRaises(errors.NotFoundError):
saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
self.assertTrue(
- saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
+ saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
# Validate the signature
signature = self._get_default_signature(
- os.path.join(export_dir, '00000010', 'export.meta'))
+ os.path.join(export_dir, '00000010', 'export.meta'))
self.assertTrue(signature.HasField('regression_signature'))
diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
index 76cfd88e1d..e7d091e18a 100644
--- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
@@ -34,12 +34,13 @@ def _create_parser(base_dir):
# create a simple parser that pulls the export_version from the directory.
def parser(path):
# Modify the path object for RegEx match for Windows Paths
- if os.name == 'nt':
- match = re.match("^" + compat.as_str_any(base_dir).replace('\\','/') + "/(\\d+)$",
- compat.as_str_any(path.path).replace('\\','/'))
+ if os.name == "nt":
+ match = re.match(
+ "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$",
+ compat.as_str_any(path.path).replace("\\", "/"))
else:
match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
- compat.as_str_any(path.path))
+ compat.as_str_any(path.path))
if not match:
return None
return path._replace(export_version=int(match.group(1)))
@@ -63,7 +64,9 @@ class GcTest(test_util.TensorFlowTestCase):
def testModExportVersion(self):
paths = [
- gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 4),
+ gc.Path("/foo", 5),
+ gc.Path("/foo", 6),
gc.Path("/foo", 9)
]
mod = gc.mod_export_version(2)
@@ -73,14 +76,21 @@ class GcTest(test_util.TensorFlowTestCase):
def testOneOfEveryNExportVersions(self):
paths = [
- gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3),
- gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7),
- gc.Path("/foo", 8), gc.Path("/foo", 33)
+ gc.Path("/foo", 0),
+ gc.Path("/foo", 1),
+ gc.Path("/foo", 3),
+ gc.Path("/foo", 5),
+ gc.Path("/foo", 6),
+ gc.Path("/foo", 7),
+ gc.Path("/foo", 8),
+ gc.Path("/foo", 33)
]
one_of = gc.one_of_every_n_export_versions(3)
self.assertEqual(
one_of(paths), [
- gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8),
+ gc.Path("/foo", 3),
+ gc.Path("/foo", 6),
+ gc.Path("/foo", 8),
gc.Path("/foo", 33)
])
@@ -98,13 +108,19 @@ class GcTest(test_util.TensorFlowTestCase):
f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
self.assertEqual(
f(paths), [
- gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6),
- gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9)
+ gc.Path("/foo", 0),
+ gc.Path("/foo", 3),
+ gc.Path("/foo", 6),
+ gc.Path("/foo", 7),
+ gc.Path("/foo", 8),
+ gc.Path("/foo", 9)
])
def testNegation(self):
paths = [
- gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 4),
+ gc.Path("/foo", 5),
+ gc.Path("/foo", 6),
gc.Path("/foo", 9)
]
mod = gc.negation(gc.mod_export_version(2))
@@ -121,8 +137,7 @@ class GcTest(test_util.TensorFlowTestCase):
gfile.MakeDirs(os.path.join(base_dir, "ignore"))
self.assertEqual(
- gc.get_paths(base_dir, _create_parser(base_dir)),
- [
+ gc.get_paths(base_dir, _create_parser(base_dir)), [
gc.Path(os.path.join(base_dir, "0"), 0),
gc.Path(os.path.join(base_dir, "1"), 1),
gc.Path(os.path.join(base_dir, "2"), 2)
@@ -131,10 +146,10 @@ class GcTest(test_util.TensorFlowTestCase):
def testMixedStrTypes(self):
temp_dir = compat.as_bytes(test.get_temp_dir())
- for sub_dir in ['str', b'bytes', u'unicode']:
+ for sub_dir in ["str", b"bytes", u"unicode"]:
base_dir = os.path.join(
- (temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode()),
- sub_dir)
+ (temp_dir
+ if isinstance(sub_dir, bytes) else temp_dir.decode()), sub_dir)
self.assertFalse(gfile.Exists(base_dir))
gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
gc.get_paths(base_dir, _create_parser(base_dir))
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
index ee8a7ccd0b..68aee2e644 100644
--- a/tensorflow/contrib/lite/allocation.h
+++ b/tensorflow/contrib/lite/allocation.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Main abstraction controlling the tflite interpreter.
// See context.h for the API for defining operations (TfLiteRegistration).
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
+#define TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
#include <cstdio>
#include <cstdlib>
@@ -91,4 +91,4 @@ class MemoryAllocation : public Allocation {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
+#endif // TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index bd87414ec3..58bc164619 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_
+#define TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_
#include <memory>
#include <vector>
@@ -104,4 +104,4 @@ class ArenaPlanner : public MemoryPlanner {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 0a097d5a69..19829e4991 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -5,25 +5,25 @@ def tflite_copts():
copts = [
"-DFARMHASH_NO_CXX_STRING",
] + select({
- "//tensorflow:android_arm64": [
+ str(Label("//tensorflow:android_arm64")): [
"-std=c++11",
"-O3",
],
- "//tensorflow:android_arm": [
+ str(Label("//tensorflow:android_arm")): [
"-mfpu=neon",
"-mfloat-abi=softfp",
"-std=c++11",
"-O3",
],
- "//tensorflow:android_x86": [
+ str(Label("//tensorflow:android_x86")): [
"-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
],
- "//tensorflow:ios_x86_64": [
+ str(Label("//tensorflow:ios_x86_64")): [
"-msse4.1",
],
"//conditions:default": [],
}) + select({
- "//tensorflow:with_default_optimizations": [],
+ str(Label("//tensorflow:with_default_optimizations")): [],
"//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
})
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 0c333f9e8c..8338fde8ac 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
#include <stdint.h>
@@ -83,7 +83,14 @@ typedef struct {
TfLiteFusedActivation activation;
} TfLiteRNNParams;
-typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams;
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+} TfLiteSequenceRNNParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteFullyConnectedParams;
typedef enum {
kTfLiteLshProjectionUnknown = 0,
@@ -91,9 +98,13 @@ typedef enum {
kTfLiteLshProjectionDense = 2,
} TfLiteLSHProjectionType;
-typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams;
+typedef struct {
+ TfLiteLSHProjectionType type;
+} TfLiteLSHProjectionParams;
-typedef struct { float beta; } TfLiteSoftmaxParams;
+typedef struct {
+ float beta;
+} TfLiteSoftmaxParams;
typedef struct {
int axis;
@@ -156,16 +167,9 @@ typedef struct {
} TfLiteLSTMParams;
typedef struct {
- int new_height;
- int new_width;
} TfLiteResizeBilinearParams;
typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int before_padding[8];
- int after_padding[8];
- int num_dimensions;
} TfLitePadParams;
typedef struct {
@@ -221,8 +225,16 @@ typedef struct {
int num_squeeze_dims;
} TfLiteSqueezeParams;
+typedef struct {
+ int begin_mask;
+ int end_mask;
+ int ellipsis_mask;
+ int new_axis_mask;
+ int shrink_axis_mask;
+} TfLiteStridedSliceParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
+#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index fca7116503..d6dfc20ae8 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -26,8 +26,8 @@ limitations under the License.
// TfLiteRegistration - the implementation of a conceptual operation.
//
// Some abstractions in this file are created and managed by Interpreter.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
#include <stdint.h>
#include <stdlib.h>
@@ -296,4 +296,4 @@ typedef struct {
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh
index 362e5bee25..e1b7b3613a 100755
--- a/tensorflow/contrib/lite/download_dependencies.sh
+++ b/tensorflow/contrib/lite/download_dependencies.sh
@@ -22,7 +22,14 @@ cd "$SCRIPT_DIR/../../.."
DOWNLOADS_DIR=tensorflow/contrib/lite/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl
-EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
+# Ensure it is being run from repo root
+if [ ! -f $BZL_FILE_PATH ]; then
+ echo "Could not find ${BZL_FILE_PATH}":
+ echo "Likely you are not running this from the root directory of the repository.";
+ exit 1;
+fi
+
+EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)"
GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h
index d5715e4f90..da193d2586 100644
--- a/tensorflow/contrib/lite/error_reporter.h
+++ b/tensorflow/contrib/lite/error_reporter.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
#include <cstdarg>
#include "tensorflow/contrib/lite/context.h"
@@ -51,4 +51,4 @@ ErrorReporter* DefaultErrorReporter();
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
index a885a57b65..0ab7aa25d0 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
@@ -29,13 +29,6 @@
#include "ios_image_load.h"
-#define LOG(x) std::cerr
-#define CHECK(x) \
- if (!(x)) { \
- LOG(ERROR) << #x << "failed"; \
- exit(1); \
- }
-
NSString* RunInferenceOnImage();
@interface RunModelViewController ()
@@ -89,8 +82,8 @@ static void GetTopN(const float* prediction, const int prediction_size, const in
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
if (file_path == NULL) {
- LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String]
- << "' in bundle.";
+ NSLog(@"Couldn't find '%@.%@' in bundle.", name, extension);
+ exit(-1);
}
return file_path;
}
@@ -106,11 +99,12 @@ NSString* RunInferenceOnImage() {
std::unique_ptr<tflite::FlatBufferModel> model(
tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]));
if (!model) {
- LOG(FATAL) << "Failed to mmap model " << [graph UTF8String];
+ NSLog(@"Failed to mmap model %@.", graph);
+ exit(-1);
}
- LOG(INFO) << "Loaded model " << [graph UTF8String];
+ NSLog(@"Loaded model %@.", graph);
model->error_reporter();
- LOG(INFO) << "resolved reporter";
+ NSLog(@"Resolved reporter.");
#ifdef TFLITE_CUSTOM_OPS_HEADER
tflite::MutableOpResolver resolver;
@@ -122,7 +116,8 @@ NSString* RunInferenceOnImage() {
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter) {
- LOG(FATAL) << "Failed to construct interpreter";
+ NSLog(@"Failed to construct interpreter.");
+ exit(-1);
}
if (num_threads != -1) {
@@ -136,7 +131,8 @@ NSString* RunInferenceOnImage() {
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
- LOG(FATAL) << "Failed to allocate tensors!";
+ NSLog(@"Failed to allocate tensors.");
+ exit(-1);
}
// Read the label list
@@ -181,7 +177,8 @@ NSString* RunInferenceOnImage() {
}
if (interpreter->Invoke() != kTfLiteOk) {
- LOG(FATAL) << "Failed to invoke!";
+ NSLog(@"Failed to invoke!");
+ exit(-1);
}
float* output = interpreter->typed_output_tensor<float>(0);
@@ -211,11 +208,9 @@ NSString* RunInferenceOnImage() {
ss << "\n";
}
- LOG(INFO) << "Predictions: " << ss.str();
-
std::string predictions = ss.str();
NSString* result = @"";
result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()];
-
+ NSLog(@"Predictions: %@", result);
return result;
}
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc
index 4d2e1ce0bc..d7f49ad875 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc
@@ -148,14 +148,22 @@ void RunInference(Settings* s) {
int wanted_width = dims->data[2];
int wanted_channels = dims->data[3];
- if (s->input_floating) {
- downsize<float>(interpreter->typed_tensor<float>(input), in, image_height,
- image_width, image_channels, wanted_height, wanted_width,
- wanted_channels, s);
- } else {
- downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
- image_height, image_width, image_channels, wanted_height,
- wanted_width, wanted_channels, s);
+ switch (interpreter->tensor(input)->type) {
+ case kTfLiteFloat32:
+ s->input_floating = true;
+ downsize<float>(interpreter->typed_tensor<float>(input), in,
+ image_height, image_width, image_channels,
+ wanted_height, wanted_width, wanted_channels, s);
+ break;
+ case kTfLiteUInt8:
+ downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
+ image_height, image_width, image_channels,
+ wanted_height, wanted_width, wanted_channels, s);
+ break;
+ default:
+ LOG(FATAL) << "cannot handle input type "
+ << interpreter->tensor(input)->type << " yet";
+ exit(-1);
}
struct timeval start_time, stop_time;
@@ -177,13 +185,22 @@ void RunInference(Settings* s) {
std::vector<std::pair<float, int>> top_results;
- if (s->input_floating) {
- get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
- num_results, threshold, &top_results, s->input_floating);
- } else {
- get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
+ int output = interpreter->outputs()[0];
+ switch (interpreter->tensor(output)->type) {
+ case kTfLiteFloat32:
+ get_top_n<float>(interpreter->typed_output_tensor<float>(0),
output_size, num_results, threshold, &top_results,
- s->input_floating);
+ true);
+ break;
+ case kTfLiteUInt8:
+ get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
+ output_size, num_results, threshold, &top_results,
+ false);
+ break;
+ default:
+ LOG(FATAL) << "cannot handle output type "
+ << interpreter->tensor(input)->type << " yet";
+ exit(-1);
}
std::vector<string> labels;
@@ -203,13 +220,11 @@ void display_usage() {
LOG(INFO) << "label_image\n"
<< "--accelerated, -a: [0|1], use Android NNAPI or note\n"
<< "--count, -c: loop interpreter->Invoke() for certain times\n"
- << "--input_floating, -f: [0|1] type of input layer is floating "
- "point numbers\n"
<< "--input_mean, -b: input mean\n"
<< "--input_std, -s: input standard deviation\n"
<< "--image, -i: image_name.bmp\n"
<< "--labels, -l: labels for the model\n"
- << "--tflite_mode, -m: model_name.tflite\n"
+ << "--tflite_model, -m: model_name.tflite\n"
<< "--threads, -t: number of threads\n"
<< "--verbose, -v: [0|1] print more information\n"
<< "\n";
@@ -223,7 +238,6 @@ int Main(int argc, char** argv) {
static struct option long_options[] = {
{"accelerated", required_argument, 0, 'a'},
{"count", required_argument, 0, 'c'},
- {"input_floating", required_argument, 0, 'f'},
{"verbose", required_argument, 0, 'v'},
{"image", required_argument, 0, 'i'},
{"labels", required_argument, 0, 'l'},
@@ -254,11 +268,6 @@ int Main(int argc, char** argv) {
s.loop_count = strtol( // NOLINT(runtime/deprecated_fn)
optarg, (char**)NULL, 10);
break;
- case 'f':
- s.input_floating = strtol( // NOLINT(runtime/deprecated_fn)
- optarg, (char**)NULL, 10);
- s.input_layer_type = "float";
- break;
case 'i':
s.input_bmp_name = optarg;
break;
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h
index ce98e06fc1..4de32e33fb 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.h
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.h
@@ -16,9 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
-#include <string>
#include "tensorflow/contrib/lite/string.h"
+namespace tflite {
+namespace label_image {
+
struct Settings {
bool verbose = false;
bool accel = false;
@@ -33,4 +35,7 @@ struct Settings {
int number_of_threads = 4;
};
+} // namespace label_image
+} // namespace tflite
+
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.md b/tensorflow/contrib/lite/examples/label_image/label_image.md
index d6019d673f..9ce32cf101 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.md
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.md
@@ -1,8 +1,12 @@
label_image for TensorFlow Lite inspired by TensorFlow's label_image.
+
+To build label_image for Android, run $TENSORFLOW_ROOT/configure
+and set Android NDK or configure NDK setting in
+$TENSORFLOW_ROOT/WORKSPACE first.
To build it for android ARMv8:
```
-> bazel build --cxxopt=-std=c++11 \
+> bazel build --config monolithic --cxxopt=-std=c++11 \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cpu=arm64-v8a \
@@ -10,13 +14,13 @@ To build it for android ARMv8:
```
or
```
-> bazel build --config android_arm64 --cxxopt=-std=c++11 \
+> bazel build --config android_arm64 --config monolithic --cxxopt=-std=c++11 \
//tensorflow/contrib/lite/examples/label_image:label_image
```
To build it for android arm-v7a:
```
-> bazel build --cxxopt=-std=c++11 \
+> bazel build --config monolithic --cxxopt=-std=c++11 \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cpu=armeabi-v7a \
@@ -24,7 +28,7 @@ To build it for android arm-v7a:
```
or
```
-> bazel build --config android_arm --cxxopt=-std=c++11 \
+> bazel build --config android_arm --config monolithic --cxxopt=-std=c++11 \
//tensorflow/contrib/lite/examples/label_image:label_image
```
diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h
index 5481aede60..57690058c4 100644
--- a/tensorflow/contrib/lite/graph_info.h
+++ b/tensorflow/contrib/lite/graph_info.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
+#define TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
#include <vector>
@@ -50,4 +50,4 @@ class GraphInfo {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
+#endif // TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 5f5981e45a..69a597dc5a 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -291,7 +291,6 @@ TfLiteStatus Interpreter::Invoke() {
TfLiteStatus status = kTfLiteOk;
if (nnapi_delegate_) {
- TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
if (next_node_to_prepare_ == nodes_and_registration_.size()) {
TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 38dd402e8a..4f732769f9 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Main abstraction controlling the tflite interpreter.
// See context.h for the API for defining operations (TfLiteRegistration).
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+#define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
#include <cstdio>
#include <cstdlib>
@@ -363,4 +363,4 @@ class Interpreter {
};
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 7e9644f36c..4195e7553c 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -50,7 +50,7 @@ cc_library(
deps = [
":op_macros",
"//tensorflow/contrib/lite:context",
- "@gemmlowp//:gemmlowp",
+ "@gemmlowp",
],
)
@@ -103,9 +103,11 @@ cc_library(
"space_to_batch_nd.cc",
"space_to_depth.cc",
"squeeze.cc",
+ "strided_slice.cc",
"sub.cc",
"svdf.cc",
"transpose.cc",
+ "unidirectional_sequence_lstm.cc",
"unidirectional_sequence_rnn.cc",
],
hdrs = [
@@ -250,6 +252,18 @@ tf_cc_test(
)
tf_cc_test(
+ name = "unidirectional_sequence_lstm_test",
+ size = "small",
+ srcs = ["unidirectional_sequence_lstm_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "unidirectional_sequence_rnn_test",
size = "small",
srcs = ["unidirectional_sequence_rnn_test.cc"],
@@ -505,6 +519,18 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "strided_slice_test",
+ size = "small",
+ srcs = ["strided_slice_test.cc"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
index cfb3369e99..41ec3cca33 100644
--- a/tensorflow/contrib/lite/kernels/activation_functor.h
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
#include <algorithm>
#include <cmath>
@@ -55,4 +55,4 @@ class ActivationFunctor {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index c75c04baea..37f499a4d0 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -322,22 +322,23 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
- const float* filter_data;
- if (data->need_hwcn_weights) {
- filter_data = GetTensorData<float>(hwcn_weights);
- } else {
- filter_data = GetTensorData<float>(filter);
- }
-
if (kernel_type == kReference) {
- reference_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input), filter_data,
- GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height, data->padding.width,
- data->padding.height, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ reference_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(filter), GetTensorDims(filter),
+ GetTensorData<float>(bias), GetTensorDims(bias),
+ params->stride_width, params->stride_height,
+ data->padding.width, data->padding.height,
+ output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
} else {
+ const float* filter_data;
+ if (data->need_hwcn_weights) {
+ filter_data = GetTensorData<float>(hwcn_weights);
+ } else {
+ filter_data = GetTensorData<float>(filter);
+ }
+
multithreaded_ops::Conv(
GetTensorData<float>(input), GetTensorDims(input), filter_data,
GetTensorDims(filter), GetTensorData<float>(bias), GetTensorDims(bias),
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index f8df797daf..0e4187d1ea 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -42,9 +42,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32);
// Check that input and output types match.
TF_LITE_ENSURE_EQ(context, input->type, output->type);
- // TODO(mgubin): only 1D positions are currently supported.
- TF_LITE_ENSURE_EQ(context, NumDimensions(positions), 1);
+ // TODO(mgubin): only 0D or 1D positions are currently supported.
+ TF_LITE_ENSURE(context, NumDimensions(positions) <= 1);
// TODO(mgubin): Only default axis == 0 is supported.
+ TF_LITE_ENSURE_EQ(context, params->axis, 0);
// Check conditions for different types.
switch (input->type) {
case kTfLiteFloat32:
@@ -64,7 +65,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
const int num_dimensions =
NumDimensions(input) + NumDimensions(positions) - 1;
- TF_LITE_ENSURE(context, params->axis < num_dimensions);
+ TF_LITE_ENSURE(context, params->axis <= num_dimensions);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
int output_index = 0;
for (int i = 0; i < params->axis; ++i) {
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index 6343d3b4ef..658d977b8d 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -48,8 +48,8 @@ class GatherOpModel : public SingleOpModel {
PopulateStringTensor(input_, data);
}
- void SetPositions(std::initializer_list<int32> data) {
- PopulateTensor<int32>(positions_, data);
+ void SetPositions(std::initializer_list<int> data) {
+ PopulateTensor<int>(positions_, data);
}
std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); }
@@ -76,6 +76,29 @@ TEST(GatherOpTest, Shuffle) {
ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2})));
}
+TEST(GatherOpTest, Test0DIndex) {
+ GatherOpModel m({2, 2}, TensorType_FLOAT32, {});
+ m.SetInputFloat({-2.0, 0.2, 0.7, 0.8});
+ m.SetPositions({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(),
+ ElementsAreArray(ArrayFloatNear({0.7, 0.8})));
+ EXPECT_THAT(m.GetOutputShape(),
+ ElementsAreArray({2}));
+}
+
+TEST(GatherOpTest, Test0DIndexWith0DResult) {
+ // 0D tensor is special case in current TFLite. Test it once to make sure
+ // existing workarounds are fine with it.
+ GatherOpModel m({3}, TensorType_FLOAT32, {});
+ m.SetInputFloat({1.0, 2.0, 3.0});
+ m.SetPositions({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(),
+ ElementsAreArray(ArrayFloatNear({2.0})));
+ EXPECT_TRUE(m.GetOutputShape().empty());
+}
+
TEST(FloatGatherOpTest, Duplicate) {
GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2});
m.SetInputFloat({-2.0, 0.2, 0.7, 0.8});
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index b531959ffb..466781cbce 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/context.h"
@@ -51,4 +51,4 @@ void SetMaxNumThreads(TfLiteContext* context, int num_threads);
} // namespace gemm_support
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index a3ecb2ebf6..38b032c6de 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -145,7 +145,7 @@ cc_library(
":types",
":round",
"//third_party/eigen3",
- "@gemmlowp//:gemmlowp",
+ "@gemmlowp",
"//tensorflow/contrib/lite:builtin_op_data",
] + select({
":haswell": tflite_deps_intel,
@@ -223,7 +223,7 @@ cc_library(
":round",
":types",
"//third_party/eigen3",
- "@gemmlowp//:gemmlowp",
+ "@gemmlowp",
"//tensorflow/contrib/lite:builtin_op_data",
] + select({
":haswell": tflite_deps_intel,
@@ -267,6 +267,8 @@ cc_library(
"optimized/neon_tensor_utils.cc",
],
hdrs = [
+ "common.h",
+ "optimized/cpu_check.h",
"optimized/neon_tensor_utils.h",
"optimized/tensor_utils_impl.h",
],
@@ -274,8 +276,11 @@ cc_library(
deps = [
":cpu_check",
":portable_tensor_utils",
+ ":types",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite/kernels:activation_functor",
+ "@arm_neon_2_x86_sse",
+ "@gemmlowp//:gemmlowp",
],
)
@@ -285,14 +290,21 @@ cc_library(
"tensor_utils.cc",
],
hdrs = [
+ "common.h",
+ "compatibility.h",
+ "optimized/cpu_check.h",
+ "optimized/neon_tensor_utils.h",
"optimized/tensor_utils_impl.h",
"reference/portable_tensor_utils.h",
"tensor_utils.h",
+ "types.h",
],
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite:builtin_op_data",
+ "@arm_neon_2_x86_sse",
+ "@gemmlowp//:gemmlowp",
] + select({
":arm": [
":neon_tensor_utils",
@@ -312,6 +324,15 @@ cc_library(
":ios_arm64": [
":neon_tensor_utils",
],
+ ":x86_64": [
+ ":neon_tensor_utils",
+ ],
+ ":x86": [
+ ":neon_tensor_utils",
+ ],
+ ":darwin": [
+ ":neon_tensor_utils",
+ ],
"//conditions:default": [
":portable_tensor_utils",
],
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index 28f19a2506..fdeacedace 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
@@ -104,4 +104,4 @@ inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h
index 796a03566a..1d963afb7e 100644
--- a/tensorflow/contrib/lite/kernels/internal/compatibility.h
+++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
#include <cassert>
#include <cstdint>
@@ -75,4 +75,4 @@ using uint16 = std::uint16_t;
using int32 = std::int32_t;
using uint32 = std::uint32_t;
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
index dea46cc120..629783d7e5 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
@@ -34,7 +34,7 @@ inline bool TestCPUFeatureNeon() {
#endif // __aarch64__
}
-#elif __ARM_NEON
+#elif defined USE_NEON || defined __ARM_NEON
inline bool TestCPUFeatureNeon() {
return true;
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index da34c8aef9..81796e295d 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
@@ -1057,4 +1057,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
} // namespace optimized_ops
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index 051ed2a2c4..fc58978964 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
@@ -1504,7 +1504,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
<< "*\n"
<< "* If you would like to carry on with the slow code, compile\n"
<< "* with this preprocessor token defined:\n"
- << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+ << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
<< "*\n"
<< "* The right thing to do, if you care about performance, is to add\n"
<< "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
@@ -1913,4 +1913,4 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
} // namespace optimized_ops
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
index 8004c24a99..f21fbf532a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h
@@ -16,8 +16,8 @@ limitations under the License.
// Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h.
// TODO(petewarden) - move this to a common location in Eigen itself.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
#define EIGEN_USE_CUSTOM_THREAD_POOL
#define EIGEN_USE_THREADS
@@ -228,4 +228,4 @@ EIGEN_DEVICE_FUNC
// clang-format on
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
index 7f78f69360..d85e06a5d5 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_
#define EIGEN_USE_CUSTOM_THREAD_POOL
#define EIGEN_USE_THREADS
@@ -140,4 +140,4 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h
index 1d5c316194..d34708b8fd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h
@@ -19,8 +19,8 @@ limitations under the License.
// clang-format off
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
#include "Eigen/Core"
@@ -164,4 +164,4 @@ typedef unsigned __int64 uint64_t;
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index b3615f4658..0bfb4e9b1f 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
#include <assert.h>
#include <stdint.h>
@@ -192,4 +192,4 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
} // namespace multithreaded_ops
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index bf0bdfb1fb..ea8502ae33 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -15,12 +15,12 @@ limitations under the License.
#include <string.h>
#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
#ifdef USE_NEON
-#include <arm_neon.h>
#define kFloatWeightsPerNeonLane 4
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 3a4af87304..b7e317dc60 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
@@ -110,4 +110,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
} // namespace tensor_utils
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 1b1f455855..8163c76cfd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
#include <assert.h>
#include <stdint.h>
@@ -1538,9 +1538,10 @@ void Add(const int32* input1_data, const Dims<4>& input1_dims,
// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
// reference_ops.h.
-template <FusedActivationFunctionType Ac, typename T>
+template <typename T>
void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
T* output_data, const Dims<4>& output_dims) {
gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
@@ -1563,15 +1564,30 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ output_activation_min, output_activation_max);
}
}
}
}
}
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
inline void BroadcastAdd(int left_shift, const uint8* input1_data,
const Dims<4>& input1_dims, int32 input1_offset,
int32 input1_multiplier, int input1_shift,
@@ -1772,9 +1788,10 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims,
// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
// reference_ops.h.
-template <FusedActivationFunctionType Ac, typename T>
+template <typename T>
void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
T* output_data, const Dims<4>& output_dims) {
gemmlowp::ScopedProfilingLabel label("BroadcastMul");
@@ -1797,15 +1814,30 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ output_activation_min, output_activation_max);
}
}
}
}
}
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int32 input1_offset, const uint8* input2_data,
const Dims<4>& input2_dims, int32 input2_offset,
@@ -3805,4 +3837,4 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
#pragma GCC diagnostic pop
#endif
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index 8e0f234545..9aabee5000 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
@@ -112,4 +112,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
} // end namespace reference_ops
} // end namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index 8a80558b32..e9b6baeaee 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
#include <algorithm>
@@ -135,4 +135,4 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
} // end namespace reference_ops
} // end namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index 7f90d731b8..afc3e26e79 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
// TDOD(ghodrat): Remove this header file and the dependency to internal data
// structure.
@@ -186,4 +186,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
} // namespace tensor_utils
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 1d86183d94..31bade26f9 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
#include <stdint.h>
#include <sys/types.h>
@@ -889,10 +889,11 @@ inline void Add(int left_shift, const uint8* input1_data,
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-template <FusedActivationFunctionType Ac>
-void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
+template <typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
NdArrayDesc<4> desc1;
@@ -914,15 +915,30 @@ void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims,
for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ output_activation_min, output_activation_max);
}
}
}
}
}
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
inline void BroadcastAdd(int left_shift, const uint8* input1_data,
const Dims<4>& input1_dims, int32 input1_offset,
int32 input1_multiplier, int input1_shift,
@@ -1053,10 +1069,11 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims,
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
-template <FusedActivationFunctionType Ac>
-void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
gemmlowp::ScopedProfilingLabel label("BroadcastMul");
NdArrayDesc<4> desc1;
@@ -1078,15 +1095,30 @@ void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ output_activation_min, output_activation_max);
}
}
}
}
}
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int32 input1_offset, const uint8* input2_data,
const Dims<4>& input2_dims, int32 input2_offset,
@@ -2330,6 +2362,18 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
}
}
+inline bool LoopCondition(int index, int stop, int stride) {
+ return stride > 0 ? index < stop : index > stop;
+}
+
+inline int StartIndex(int start, int stride, int dim, bool masked) {
+ return masked ? (stride > 0 ? 0 : dim - 1) : start;
+}
+
+inline int StopIndex(int stop, int stride, int dim, bool masked) {
+ return masked ? (stride > 0 ? dim : -1) : stop;
+}
+
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
int begin_mask, int end_mask,
@@ -2337,20 +2381,35 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& stops,
const std::vector<int>& strides, T* output_data,
const Dims<4>& output_dims) {
- const int start_b = (begin_mask & 8) ? 0 : starts[3];
- const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3];
- const int start_h = (begin_mask & 4) ? 0 : starts[2];
- const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2];
- const int start_w = (begin_mask & 2) ? 0 : starts[1];
- const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1];
- const int start_d = (begin_mask & 1) ? 0 : starts[0];
- const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0];
+ TFLITE_DCHECK_EQ(starts.size(), 4);
+ TFLITE_DCHECK_EQ(stops.size(), 4);
+ TFLITE_DCHECK_EQ(strides.size(), 4);
+ const int start_b =
+ StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8);
+ const int stop_b =
+ StopIndex(stops[3], strides[3], input_dims.sizes[3], end_mask & 8);
+ const int start_h =
+ StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4);
+ const int stop_h =
+ StopIndex(stops[2], strides[2], input_dims.sizes[2], end_mask & 4);
+ const int start_w =
+ StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2);
+ const int stop_w =
+ StopIndex(stops[1], strides[1], input_dims.sizes[1], end_mask & 2);
+ const int start_d =
+ StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1);
+ const int stop_d =
+ StopIndex(stops[0], strides[0], input_dims.sizes[0], end_mask & 1);
T* out_ptr = output_data;
- for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
- for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
- for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
- for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) {
+ for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]);
+ in_b += strides[3]) {
+ for (int in_h = start_h; LoopCondition(in_h, stop_h, strides[2]);
+ in_h += strides[2]) {
+ for (int in_w = start_w; LoopCondition(in_w, stop_w, strides[1]);
+ in_w += strides[1]) {
+ for (int in_d = start_d; LoopCondition(in_d, stop_d, strides[0]);
+ in_d += strides[0]) {
*out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
}
}
@@ -2628,4 +2687,4 @@ void Transpose(const T* input, const Dims<4>& input_dims, T* output,
} // namespace reference_ops
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h
index 38525b0e20..f299d0bd87 100644
--- a/tensorflow/contrib/lite/kernels/internal/round.h
+++ b/tensorflow/contrib/lite/kernels/internal/round.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
#include <cmath>
@@ -36,4 +36,4 @@ inline T TfLiteRound(const T x) {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index 1961e1a2d5..dfe76c2afd 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
#include <vector>
#include "tensorflow/contrib/lite/context.h"
@@ -83,4 +83,4 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
index 904a97803a..f4181b18a8 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
#ifndef USE_NEON
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index e7e2994397..40d144979b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
#include "tensorflow/contrib/lite/builtin_op_data.h"
@@ -113,4 +113,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
} // namespace tensor_utils
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 5989ac8fcd..afe131b06e 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
@@ -134,4 +134,4 @@ bool IsPackedWithoutStrides(const Dims<N>& dims) {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index 25556ae456..bfdfba00f5 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
@@ -44,6 +44,22 @@ inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
return nullptr;
}
+// Determines whether tensor is constant.
+inline bool IsConstantTensor(TfLiteTensor* tensor) {
+ return tensor->allocation_type == kTfLiteMmapRo;
+}
+
+// Determines whether tensor is dynamic. Note that a tensor can be non-const and
+// not dynamic. This function specificially checks for a dynamic tensor.
+inline bool IsDynamicTensor(TfLiteTensor* tensor) {
+ return tensor->allocation_type == kTfLiteDynamic;
+}
+
+// Sets tensor to dynamic.
+inline void SetTensorToDynamic(TfLiteTensor* tensor) {
+ tensor->allocation_type = kTfLiteDynamic;
+}
+
// Calculates the multiplication factor for a quantized convolution (or
// quantized depthwise convolution) involving the given tensors. Returns an
// error if the scales of the tensors are not compatible.
@@ -62,4 +78,4 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h
index 63670efcb1..7568eaa88e 100644
--- a/tensorflow/contrib/lite/kernels/op_macros.h
+++ b/tensorflow/contrib/lite/kernels/op_macros.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
#include <cstdio>
@@ -31,4 +31,4 @@ limitations under the License.
if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
} while (0)
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 1a0d9d1505..4003ed10df 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -33,65 +33,92 @@ enum KernelType {
kGenericOptimized,
};
-// TODO(nupurgarg): Padding represented as a tensor is ignored. Only use the
-// `left_padding` and `right_padding` specified in `params`.
struct PadContext {
PadContext(TfLiteContext* context, TfLiteNode* node) {
- params = reinterpret_cast<TfLitePadParams*>(node->builtin_data);
input = GetInput(context, node, 0);
+ paddings = GetInput(context, node, 1);
output = GetOutput(context, node, 0);
+ dims = NumDimensions(input);
}
- TfLitePadParams* params;
TfLiteTensor* input;
+ TfLiteTensor* paddings;
TfLiteTensor* output;
+ int dims;
};
+// Resizes output array based on the input size and padding size. This function
+// is callable from both Prepare() and Eval() as long as the caller ensures the
+// paddings data is present.
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ PadContext* op_context) {
+ // Ensures the paddings array is dims x 2.
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0),
+ op_context->dims);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2);
+
+ // Determines the size of the output tensor.
+ TfLiteIntArray* input_size = op_context->input->dims;
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
+ const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
+
+ for (int idx = 0; idx < op_context->dims; ++idx) {
+ int before_padding = *paddings_data++;
+ int after_padding = *paddings_data++;
+
+ TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0),
+ "Pad value has to be greater than equal to 0.");
+
+ output_size->data[idx] =
+ (input_size->data[idx] + before_padding + after_padding);
+ }
+
+ return context->ResizeTensor(context, op_context->output, output_size);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- // Determines size of output tensor.
PadContext op_context(context, node);
- int dims = NumDimensions(op_context.input);
- TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
// TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
- TF_LITE_ENSURE_EQ(context, dims, 4);
-
- const TfLiteIntArray* input_size = op_context.input->dims;
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims);
- for (int idx = 0; idx < dims; ++idx) {
- TF_LITE_ENSURE_MSG(context,
- (op_context.params->before_padding[idx] >= 0 &&
- op_context.params->after_padding[idx] >= 0),
- "Pad value has to be greater than equal to 0.");
- output_size->data[idx] =
- (input_size->data[idx] + op_context.params->before_padding[idx] +
- op_context.params->after_padding[idx]);
- }
+ TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
- return context->ResizeTensor(context, op_context.output, output_size);
+ // Exit early if paddings is a non-const tensor. Set output tensor to
+ // dynamic so output size can be determined in Eval.
+ if (!IsConstantTensor(op_context.paddings)) {
+ SetTensorToDynamic(op_context.output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputTensor(context, &op_context);
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PadContext op_context(context, node);
- std::vector<int> before_padding(
- op_context.params->before_padding,
- op_context.params->before_padding + op_context.params->num_dimensions);
- std::vector<int> after_padding(
- op_context.params->after_padding,
- op_context.params->after_padding + op_context.params->num_dimensions);
-
- // TODO(nupurgarg): Change TOCO's implementation to use padding arrays
- // in forward order (depth, width, height, batch).
- // Converts from int[] = {depth, width, height, batch} to int[] = {batch,
- // height, width, depth} to match TOCO's implementation of pad in
- // referenced_ops.h and optimized_ops.h.
- std::reverse(before_padding.begin(), before_padding.end());
- std::reverse(after_padding.begin(), after_padding.end());
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TfLiteTensorRealloc(op_context.output->bytes, op_context.output);
+ }
+
+ // TODO(nupurgarg): Change kernel implementation to take in int* instead of
+ // vector<int> to remove malloc from Eval().
+ // Create before and after padding arrays that are accepted by the kernel.
+ std::vector<int> before_padding;
+ std::vector<int> after_padding;
+ const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
+
+ // TODO(nupurgarg): Change kernel implementation to use padding arrays in
+ // forward order (depth, width, height, batch).
+ // Build paddings in order of int[] = {batch, height, width, depth} to match
+ // kernel implementation of Pad in referenced_ops.h and optimized_ops.h.
+ for (int idx = op_context.dims - 1; idx >= 0; --idx) {
+ before_padding.push_back(paddings_data[idx * 2]);
+ after_padding.push_back(paddings_data[idx * 2 + 1]);
+ }
#define TF_LITE_PAD(type, scalar) \
type::Pad(GetTensorData<scalar>(op_context.input), \
diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc
index f3ea9417df..28834ad071 100644
--- a/tensorflow/contrib/lite/kernels/pad_test.cc
+++ b/tensorflow/contrib/lite/kernels/pad_test.cc
@@ -25,52 +25,87 @@ using ::testing::ElementsAreArray;
class PadOpModel : public SingleOpModel {
public:
- PadOpModel(std::initializer_list<int> input_shape,
- std::initializer_list<int> before_padding,
- std::initializer_list<int> after_padding) {
- input_ = AddInput(TensorType_FLOAT32);
- output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(
- BuiltinOperator_PAD, BuiltinOptions_PadOptions,
- CreatePadOptions(builder_, builder_.CreateVector<int>(before_padding),
- builder_.CreateVector<int>(after_padding))
- .Union());
- BuildInterpreter({input_shape});
- }
-
void SetInput(std::initializer_list<float> data) {
PopulateTensor<float>(input_, data);
}
+ void SetPaddings(std::initializer_list<int> paddings) {
+ PopulateTensor<int>(paddings_, paddings);
+ }
+
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
- private:
+ protected:
int input_;
int output_;
+ int paddings_;
+};
+
+// Tests case where paddings is a const tensor.
+//
+// Example usage is as follows:
+// PadOpDynamicModel m(input_shape, paddings_shape, paddings_data);
+// m.SetInput(input_data);
+// m.Invoke();
+class PadOpConstModel : public PadOpModel {
+ public:
+ PadOpConstModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> paddings_shape,
+ std::initializer_list<int> paddings) {
+ input_ = AddInput(TensorType_FLOAT32);
+ paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
+ CreatePadOptions(builder_).Union());
+ BuildInterpreter({input_shape});
+ }
+};
+
+// Test case where paddings is a non-const tensor.
+//
+// Example usage is as follows:
+// PadOpDynamicModel m(input_shape, paddings_shape);
+// m.SetInput(input_data);
+// m.SetPaddings(paddings_data);
+// m.Invoke();
+class PadOpDynamicModel : public PadOpModel {
+ public:
+ PadOpDynamicModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> paddings_shape) {
+ input_ = AddInput(TensorType_FLOAT32);
+ paddings_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
+ CreatePadOptions(builder_).Union());
+ BuildInterpreter({input_shape, paddings_shape});
+ }
};
TEST(PadOpTest, TooManyDimensions) {
EXPECT_DEATH(
- PadOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9},
- {1, 2, 3, 4, 5, 6, 7, 8, 9}),
+ PadOpConstModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9, 2},
+ {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}),
"dims != 4");
}
-// TODO(nupurgarg): Test case where before padding and after padding arrays
-// don't contain the same number of dimensions.
TEST(PadOpTest, UnequalDimensions) {
- EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {1, 2, 3}, {1, 2, 3}),
- "dims != op_context.params->num_dimensions");
+ EXPECT_DEATH(PadOpConstModel({1, 1, 2, 1}, {3, 2}, {1, 1, 2, 2, 3, 3}),
+ "3 != 4");
}
TEST(PadOpTest, InvalidPadValue) {
- EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {0, 1, 2, 0}, {0, -1, -1, 0}),
- "Pad value has to be greater than equal to 0.");
+ EXPECT_DEATH(
+ PadOpConstModel({1, 1, 2, 1}, {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0}),
+ "Pad value has to be greater than equal to 0.");
}
-TEST(PadOpTest, SimpleTest) {
- PadOpModel m({1, 2, 2, 1}, {0, 1, 1, 0}, {0, 1, 1, 0});
+TEST(PadOpTest, SimpleConstTest) {
+ // Padding is represented as four 2-D lists representing above padding and
+ // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
+ PadOpConstModel m({1, 2, 2, 1}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0});
m.SetInput({1, 2, 3, 4});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
@@ -78,10 +113,30 @@ TEST(PadOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
-TEST(PadOpTest, AdvancedTest) {
- // The padding is input in the order of batch, height, width, depth.
- PadOpModel m({1, 2, 3, 1}, {0, 0, 1, 0}, {0, 2, 3, 0});
+TEST(PadOpTest, SimpleDynamicTest) {
+ PadOpDynamicModel m({1, 2, 2, 1}, {4, 2});
+ m.SetInput({1, 2, 3, 4});
+ m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
+ 0, 0, 0, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+TEST(PadOpTest, AdvancedConstTest) {
+ PadOpConstModel m({1, 2, 3, 1}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0});
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
+}
+
+TEST(PadOpTest, AdvancedDynamicTest) {
+ PadOpDynamicModel m({1, 2, 3, 1}, {4, 2});
m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
m.Invoke();
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
index 3a60274524..40b8476b37 100644
--- a/tensorflow/contrib/lite/kernels/padding.h
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
namespace tflite {
@@ -25,4 +25,4 @@ inline int ComputePadding(int stride, int in_size, int filter_size,
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 45ad5f1890..f605deaa5b 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -48,6 +48,7 @@ TfLiteRegistration* Register_MUL();
TfLiteRegistration* Register_L2_NORMALIZATION();
TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION();
TfLiteRegistration* Register_LSTM();
+TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
TfLiteRegistration* Register_PAD();
TfLiteRegistration* Register_RESHAPE();
TfLiteRegistration* Register_RESIZE_BILINEAR();
@@ -57,6 +58,7 @@ TfLiteRegistration* Register_GATHER();
TfLiteRegistration* Register_TRANSPOSE();
TfLiteRegistration* Register_MEAN();
TfLiteRegistration* Register_SQUEEZE();
+TfLiteRegistration* Register_STRIDED_SLICE();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -89,6 +91,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
Register_LOCAL_RESPONSE_NORMALIZATION());
AddBuiltin(BuiltinOperator_LSTM, Register_LSTM());
+ AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ Register_UNIDIRECTIONAL_SEQUENCE_LSTM());
AddBuiltin(BuiltinOperator_PAD, Register_PAD());
AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE());
AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR());
@@ -100,6 +104,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_DIV, Register_DIV());
AddBuiltin(BuiltinOperator_SUB, Register_SUB());
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
+ AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
}
TfLiteRegistration* BuiltinOpResolver::FindOp(
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index 28f5e0fcc8..b9cff0ae21 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#include <unordered_map>
#include "tensorflow/contrib/lite/context.h"
@@ -47,4 +47,4 @@ class BuiltinOpResolver : public OpResolver {
} // namespace ops
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index 1613c9a89f..9a419af023 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -33,49 +33,53 @@ enum KernelType {
};
constexpr int kInputTensor = 0;
+constexpr int kSizeTensor = 1;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- auto* params =
- reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
-
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* size = GetInput(context, node, kSizeTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// TODO(ahentz): Our current implementations rely on the inputs being 4D.
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
// TODO(ahentz): Our current implementations only support float32.
- TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
- TF_LITE_ENSURE_EQ(context, input->type, output->type);
-
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
- output_size->data[0] = input->dims->data[0];
- output_size->data[1] = params->new_height;
- output_size->data[2] = params->new_width;
- output_size->data[3] = input->dims->data[3];
-
- return context->ResizeTensor(context, output, output_size);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
+ // ResizeBilinear creates a float tensor even when the input is made of
+ // integers.
+ output->type = kTfLiteFloat32;
+
+ // TODO(ahentz): if the input is constant, we can allocate here.
+ output->allocation_type = kTfLiteDynamic;
+ return kTfLiteOk;
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params =
- reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
-
TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* size = GetInput(context, node, kSizeTensor);
- // We have to fake a tensor here, to satisfy ResizeBilinear().
- int32 output_size_data[2] = {params->new_height, params->new_width};
+ // TODO(ahentz): we only need to do this here if it wasn't done in Eval().
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ const int32* size_data = GetTensorData<int32>(size);
+ output_size->data[1] = size_data[0];
+ output_size->data[2] = size_data[1];
+ output_size->data[3] = input->dims->data[3];
+ context->ResizeTensor(context, output, output_size);
+ TfLiteTensorRealloc(output->bytes, output);
if (output->type == kTfLiteFloat32) {
#define TF_LITE_RESIZE_BILINEAR(type) \
type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
- output_size_data, GetTensorDims({1, 1, 1, 2}), \
+ GetTensorData<int32>(size), GetTensorDims(size), \
GetTensorData<float>(output), GetTensorDims(output))
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
index 314a71e210..2b1aaf654f 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
@@ -25,47 +25,52 @@ using ::testing::ElementsAreArray;
class ResizeBilinearOpModel : public SingleOpModel {
public:
- ResizeBilinearOpModel(std::initializer_list<int> input_shape, int new_height,
- int new_width) {
+ ResizeBilinearOpModel(std::initializer_list<int> input_shape) {
input_ = AddInput(TensorType_FLOAT32);
+ size_ = AddInput(TensorType_INT32);
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(
- BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions,
- CreateResizeBilinearOptions(builder_, new_height, new_width).Union());
- BuildInterpreter({input_shape});
+ SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
+ BuiltinOptions_ResizeBilinearOptions,
+ CreateResizeBilinearOptions(builder_).Union());
+ BuildInterpreter({input_shape, {2}});
}
void SetInput(std::initializer_list<float> data) {
PopulateTensor(input_, data);
}
+ void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
private:
int input_;
+ int size_;
int output_;
};
TEST(ResizeBilinearOpTest, HorizontalResize) {
- ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3);
+ ResizeBilinearOpModel m({1, 1, 2, 1});
m.SetInput({3, 6});
+ m.SetSize({1, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}
TEST(ResizeBilinearOpTest, VerticalResize) {
- ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1);
+ ResizeBilinearOpModel m({1, 2, 1, 1});
m.SetInput({3, 9});
+ m.SetSize({3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
- ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3);
+ ResizeBilinearOpModel m({1, 2, 2, 1});
m.SetInput({
3, 6, //
9, 12 //
});
+ m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
@@ -75,13 +80,14 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
- ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3);
+ ResizeBilinearOpModel m({2, 2, 2, 1});
m.SetInput({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
+ m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
3, 5, 6, //
@@ -94,11 +100,12 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
}
TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
- ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3);
+ ResizeBilinearOpModel m({1, 2, 2, 2});
m.SetInput({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
+ m.SetSize({3, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
3, 4, 5, 8, 6, 10, //
diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/contrib/lite/kernels/squeeze_test.cc
index 409227b626..a8aab88357 100644
--- a/tensorflow/contrib/lite/kernels/squeeze_test.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze_test.cc
@@ -22,6 +22,7 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using ::testing::IsEmpty;
class BaseSqueezeOpModel : public SingleOpModel {
public:
@@ -103,6 +104,16 @@ TEST(FloatSqueezeOpTest, SqueezeNegativeAxis) {
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
+TEST(FloatSqueezeOpTest, SqueezeAllDims) {
+ std::initializer_list<float> data = {3.85};
+ FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 1, 1, 1, 1, 1, 1}},
+ {TensorType_FLOAT32, {1}}, {});
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.85}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
new file mode 100644
index 0000000000..91ba4a9b78
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -0,0 +1,256 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <cmath>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace strided_slice {
+
+enum KernelType {
+ kReference,
+ // TODO(soroosh): add kGenericOptimized
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kBeginTensor = 1;
+constexpr int kEndTensor = 2;
+constexpr int kStridesTensor = 3;
+constexpr int kOutputTensor = 0;
+
+struct StridedSliceContext {
+ StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
+ params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
+ input = GetInput(context, node, kInputTensor);
+ begin = GetInput(context, node, kBeginTensor);
+ end = GetInput(context, node, kEndTensor);
+ strides = GetInput(context, node, kStridesTensor);
+ output = GetOutput(context, node, kOutputTensor);
+ dims = NumDimensions(input);
+ }
+ TfLiteStridedSliceParams* params;
+ TfLiteTensor* input;
+ TfLiteTensor* begin;
+ TfLiteTensor* end;
+ TfLiteTensor* strides;
+ TfLiteTensor* output;
+ int dims;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ StridedSliceContext op_context(context, node);
+
+ // Ensure validity of input tensor and its dimension
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+ // Only INT32 begin/end/strides are supported
+ // TODO(soroosh) add support for INT64
+ TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
+ TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
+ "StridedSlice op only supports 1D-4D input arrays.");
+
+ // TODO(soroosh): add the following missing functionalities
+ TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
+ "ellipsis_mask is not implemented yet.");
+ TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
+ "new_axis_mask is not implemented yet.");
+ TF_LITE_ENSURE_MSG(context, op_context.params->shrink_axis_mask == 0,
+ "shrink_axis_mask is not implemented yet.");
+
+ // TODO(soroosh): optimize for constant tensors to do allocation in Prepare
+ op_context.output->allocation_type = kTfLiteDynamic;
+ return kTfLiteOk;
+} // namespace strided_slice
+
+// TODO(soroosh): consolidate with BytesRequired in interpreter.h
+TfLiteStatus BytesRequired(TfLiteContext* context, TfLiteType type,
+ const int* dims, int dims_size, size_t* bytes) {
+ // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
+ // MultiplyWithoutOverflow.
+ TF_LITE_ENSURE(context, bytes != nullptr);
+ size_t count = 1;
+ for (int k = 0; k < dims_size; k++) count *= dims[k];
+ switch (type) {
+ case kTfLiteFloat32:
+ *bytes = sizeof(float) * count;
+ break;
+ case kTfLiteInt32:
+ *bytes = sizeof(int32_t) * count;
+ break;
+ case kTfLiteUInt8:
+ *bytes = sizeof(uint8_t) * count;
+ break;
+ case kTfLiteInt64:
+ *bytes = sizeof(int64_t) * count;
+ break;
+ default:
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Reverse order of bits in the mask to match the expected order in kernel
+inline int ReverseMaskBits(int mask, int num_dimensions) {
+ int out = 0;
+ for (int dim = 0; dim < num_dimensions; dim++) {
+ out <<= 1;
+ out += (mask & 1);
+ mask >>= 1;
+ }
+ return out;
+}
+
+// This Op only supports 1-4D cases and since we use the reference 4D
+// implementation, the 1-3D tensors are mapped to 4D.
+const int kMaxDim = 4;
+
+inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
+ return (divisor + (dividend % divisor)) % divisor;
+}
+
+inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
+ return pos_stride
+ ? (index >= dim ? dim
+ : PositiveRemainder(
+ std::min(std::max(index, -dim), dim), dim))
+ : (index < -dim
+ ? -1
+ : PositiveRemainder(
+ std::min(std::max(index, -dim), dim - 1), dim));
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ StridedSliceContext op_context(context, node);
+
+ std::vector<int> starts;
+ std::vector<int> stops;
+ std::vector<int> strides;
+
+ // Determine size of output tensor and map indices
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(op_context.dims);
+ for (int idx = op_context.dims - 1; idx >= 0; --idx) {
+ int dim = op_context.input->dims->data[idx];
+ int32_t stride = GetTensorData<int32_t>(op_context.strides)[idx];
+ TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
+ bool pos_stride = stride > 0;
+
+ int32_t begin =
+ op_context.params->begin_mask & (1 << idx)
+ ? pos_stride ? 0 : dim - 1
+ : ClampedIndex(GetTensorData<int32_t>(op_context.begin)[idx], dim,
+ pos_stride);
+ int32_t end =
+ op_context.params->end_mask & (1 << idx)
+ ? pos_stride ? dim : -1
+ : ClampedIndex(GetTensorData<int32_t>(op_context.end)[idx], dim,
+ pos_stride);
+
+ // This is valid for both positive and negative strides
+ output_shape->data[idx] = ceil((end - begin) / static_cast<float>(stride));
+ output_shape->data[idx] =
+ output_shape->data[idx] < 0 ? 0 : output_shape->data[idx];
+ starts.emplace_back(begin);
+ stops.emplace_back(end);
+ strides.emplace_back(stride);
+ }
+
+ for (int i = op_context.dims; i < kMaxDim; i++) {
+ starts.emplace_back(0);
+ stops.emplace_back(1);
+ strides.emplace_back(1);
+ }
+
+ TF_LITE_ENSURE_STATUS(
+ context->ResizeTensor(context, op_context.output, output_shape));
+
+ size_t required_bytes;
+ TF_LITE_ENSURE_OK(
+ context,
+ BytesRequired(context, op_context.output->type, output_shape->data,
+ output_shape->size, &required_bytes));
+ TfLiteTensorRealloc(required_bytes, op_context.output);
+
+ op_context.params->begin_mask =
+ ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
+ op_context.params->end_mask =
+ ReverseMaskBits(op_context.params->end_mask, op_context.dims);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice( \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorDims(op_context.input), op_context.params->begin_mask, \
+ op_context.params->end_mask, starts, stops, strides, \
+ GetTensorData<data_type>(op_context.output), \
+ GetTensorDims(op_context.output))
+
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, float);
+ }
+ break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
+ }
+ break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
+ }
+ break;
+ default:
+ context->ReportError(context,
+ "Type is currently not supported "
+ "by StridedSlice.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_STRIDED_SLICE
+ return kTfLiteOk;
+}
+
+} // namespace strided_slice
+
+TfLiteRegistration* Register_STRIDED_SLICE_REF() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, strided_slice::Prepare,
+ strided_slice::Eval<strided_slice::kReference>};
+ return &r;
+}
+
+// TODO(soroosh): add optimized
+TfLiteRegistration* Register_STRIDED_SLICE() {
+ return Register_STRIDED_SLICE_REF();
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
new file mode 100644
index 0000000000..cd4a364682
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
@@ -0,0 +1,375 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class StridedSliceOpModel : public SingleOpModel {
+ public:
+ StridedSliceOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> begin_shape,
+ std::initializer_list<int> end_shape,
+ std::initializer_list<int> strides_shape, int begin_mask,
+ int end_mask, int ellipsis_mask, int new_axis_mask,
+ int shrink_axis_mask) {
+ input_ = AddInput(TensorType_FLOAT32);
+ begin_ = AddInput(TensorType_INT32);
+ end_ = AddInput(TensorType_INT32);
+ strides_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
+ CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
+ new_axis_mask, shrink_axis_mask)
+ .Union());
+ BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ void SetBegin(std::initializer_list<int32> data) {
+ PopulateTensor<int32>(begin_, data);
+ }
+ void SetEnd(std::initializer_list<int32> data) {
+ PopulateTensor<int32>(end_, data);
+ }
+ void SetStrides(std::initializer_list<int32> data) {
+ PopulateTensor<int32>(strides_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int begin_;
+ int end_;
+ int strides_;
+ int output_;
+};
+
+TEST(StridedSliceOpTest, UnsupportedInputSize) {
+ EXPECT_DEATH(
+ StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0),
+ "StridedSlice op only supports 1D-4D input arrays.");
+}
+
+TEST(StridedSliceOpTest, UnssupportedArgs) {
+ EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0),
+ "ellipsis_mask is not implemented yet.");
+ EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0),
+ "new_axis_mask is not implemented yet.");
+ EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 0, 1),
+ "shrink_axis_mask is not implemented yet.");
+}
+
+TEST(StridedSliceOpTest, In1D) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
+}
+
+TEST(StridedSliceOpTest, In1D_EmptyOutput) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({10});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeBegin) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-3});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3}));
+}
+
+TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-5});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeEnd) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({-2});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
+}
+
+TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-3});
+ m.SetEnd({5});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
+}
+
+TEST(StridedSliceOpTest, In1D_BeginMask) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-2});
+ m.SetEnd({-3});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+}
+
+TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({5});
+ m.SetEnd({2});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4}));
+}
+
+TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({2});
+ m.SetEnd({-4});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2}));
+}
+
+TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({-3});
+ m.SetEnd({-5});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1}));
+}
+
+TEST(StridedSliceOpTest, In1D_EndMask) {
+ StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4});
+ m.SetBegin({1});
+ m.SetEnd({3});
+ m.SetStrides({1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4}));
+}
+TEST(StridedSliceOpTest, In1D_NegStride) {
+ StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3});
+ m.SetBegin({-1});
+ m.SetEnd({-4});
+ m.SetStrides({-1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1}));
+}
+
+TEST(StridedSliceOpTest, In1D_EvenLenStride2) {
+ StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2});
+ m.SetBegin({0});
+ m.SetEnd({2});
+ m.SetStrides({2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
+}
+TEST(StridedSliceOpTest, In1D_OddLenStride2) {
+ StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3});
+ m.SetBegin({0});
+ m.SetEnd({3});
+ m.SetStrides({2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3}));
+}
+
+TEST(StridedSliceOpTest, In2D_Identity) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({0, 0});
+ m.SetEnd({2, 3});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+TEST(StridedSliceOpTest, In2D) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, 0});
+ m.SetEnd({2, 2});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5}));
+}
+
+TEST(StridedSliceOpTest, In2D_Stride2) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({0, 0});
+ m.SetEnd({2, 3});
+ m.SetStrides({2, 2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3}));
+}
+
+TEST(StridedSliceOpTest, In2D_NegStride) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, -1});
+ m.SetEnd({2, -4});
+ m.SetStrides({2, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4}));
+}
+
+TEST(StridedSliceOpTest, In2D_BeginMask) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, 0});
+ m.SetEnd({2, 2});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5}));
+}
+
+TEST(StridedSliceOpTest, In2D_EndMask) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, 0});
+ m.SetEnd({2, 2});
+ m.SetStrides({1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5, 6}));
+}
+
+TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, -2});
+ m.SetEnd({2, -4});
+ m.SetStrides({1, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4}));
+}
+TEST(StridedSliceOpTest, In2D_NegStrideEndMask) {
+ StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.SetBegin({1, -2});
+ m.SetEnd({2, -3});
+ m.SetStrides({1, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 4}));
+}
+
+TEST(StridedSliceOpTest, In3D_Identity) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({1, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+}
+
+TEST(StridedSliceOpTest, In3D_NegStride) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({-1, -1, -1});
+ m.SetEnd({-3, -4, -3});
+ m.SetStrides({-1, -1, -1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}));
+}
+TEST(StridedSliceOpTest, In3D_Strided2) {
+ StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0);
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ m.SetBegin({0, 0, 0});
+ m.SetEnd({2, 3, 2});
+ m.SetStrides({2, 2, 2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index b69f2b3e4b..3a58e7ec32 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -49,7 +49,7 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
return matchers;
}
-int SingleOpModel::AddTensor(TensorData t) {
+int SingleOpModel::AddTensor(TensorData t, std::initializer_list<int> data) {
int id = tensors_.size();
// This is slightly different depending on whether we are adding a
@@ -78,8 +78,23 @@ int SingleOpModel::AddTensor(TensorData t) {
builder_.CreateVector<int64_t>({t.zero_point}));
}
- tensors_.push_back(CreateTensor(builder_, builder_.CreateVector<int>({}),
- t.type, /*buffer=*/0,
+ int buffer_id = 0;
+ if (data.size()) {
+ // Initialize buffers list with empty buffer to allow for non-const tensors.
+ if (buffers_.empty()) {
+ buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({})));
+ }
+
+ // Add data as a Buffer to buffers list.
+ buffer_id = buffers_.size();
+ auto data_buffer =
+ builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.begin()),
+ sizeof(int) * data.size());
+ buffers_.push_back(CreateBuffer(builder_, data_buffer));
+ }
+
+ tensors_.push_back(CreateTensor(builder_, builder_.CreateVector<int>(t.shape),
+ t.type, /*buffer=*/buffer_id,
/*name=*/0, q_params));
tensor_data_[id] = t;
@@ -88,7 +103,15 @@ int SingleOpModel::AddTensor(TensorData t) {
}
int SingleOpModel::AddInput(const TensorData& t) {
- int id = AddTensor(t);
+ int id = AddTensor(t, {});
+ inputs_.push_back(id);
+ return id;
+}
+
+int SingleOpModel::AddConstInput(TensorType type,
+ std::initializer_list<int> data,
+ std::initializer_list<int> shape) {
+ int id = AddTensor(TensorData{type, shape}, data);
inputs_.push_back(id);
return id;
}
@@ -100,7 +123,7 @@ int SingleOpModel::AddNullInput() {
}
int SingleOpModel::AddOutput(const TensorData& t) {
- int id = AddTensor(t);
+ int id = AddTensor(t, {});
outputs_.push_back(id);
return id;
}
@@ -142,8 +165,7 @@ void SingleOpModel::BuildInterpreter(
subgraphs.push_back(subgraph);
auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs);
- std::vector<flatbuffers::Offset<Buffer>> buffers_vec;
- auto buffers = builder_.CreateVector(buffers_vec);
+ auto buffers = builder_.CreateVector(buffers_);
auto description = builder_.CreateString("programmatic model");
builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
subgraphs_flatbuffer, description, buffers));
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index 531c1366a8..cc445299ff 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
#include <vector>
@@ -98,6 +98,10 @@ class SingleOpModel {
int AddInput(TensorType type) { return AddInput(TensorData{type}); }
int AddInput(const TensorData& t);
+ // Add a Tensor containing const data and return the tensor id.
+ int AddConstInput(TensorType type, std::initializer_list<int> data,
+ std::initializer_list<int> shape);
+
// Add a null input tensor (optional input) and return kOptionalTensor.
int AddNullInput();
@@ -181,7 +185,7 @@ class SingleOpModel {
std::unique_ptr<tflite::Interpreter> interpreter_;
private:
- int AddTensor(TensorData t);
+ int AddTensor(TensorData t, std::initializer_list<int> data);
std::map<int, TensorData> tensor_data_;
std::vector<int32_t> inputs_;
@@ -189,6 +193,7 @@ class SingleOpModel {
std::vector<flatbuffers::Offset<Tensor>> tensors_;
std::vector<flatbuffers::Offset<OperatorCode>> opcodes_;
std::vector<flatbuffers::Offset<Operator>> operators_;
+ std::vector<flatbuffers::Offset<Buffer>> buffers_;
std::map<string, std::function<TfLiteRegistration*()>> custom_registrations_;
};
@@ -197,4 +202,4 @@ template <>
std::vector<string> SingleOpModel::ExtractVector(int index);
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
new file mode 100644
index 0000000000..9cdb58714e
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -0,0 +1,527 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace unidirectional_sequence_lstm {
+
+// Input Tensors of size {max_time, n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1; // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9; // Optional
+constexpr int kCellToForgetWeightsTensor = 10; // Optional
+constexpr int kCellToOutputWeightsTensor = 11; // Optional
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 12; // Optional
+constexpr int kForgetGateBiasTensor = 13;
+constexpr int kCellGateBiasTensor = 14;
+constexpr int kOutputGateBiasTensor = 15;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 16; // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 17; // Optional
+
+// Output tensors.
+constexpr int kScratchBufferTensor = 0;
+constexpr int kOutputStateTensor = 1;
+constexpr int kCellStateTensor = 2;
+constexpr int kOutputTensor = 3;
+
+// Check that input tensor dimensions matches with each other.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+ TfLiteNode* node, int n_input,
+ int n_output, int n_cell) {
+ auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+
+ // Making sure clipping parameters have valid values.
+ // == 0 means no clipping
+ // > 0 means clipping
+ TF_LITE_ENSURE(context, params->cell_clip >= 0);
+ TF_LITE_ENSURE(context, params->proj_clip >= 0);
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ if (input_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+ }
+
+ TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+ TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+ TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ if (recurrent_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+ n_output);
+ }
+
+ TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+ n_output);
+
+ TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+ n_output);
+
+ // We make sure the input-gate's parameters are either both present (regular
+ // LSTM) or not at all (CIFG-LSTM).
+ const bool cifg_weights_all_or_none =
+ ((input_to_input_weights != nullptr) &&
+ (recurrent_to_input_weights != nullptr)) ||
+ ((input_to_input_weights == nullptr) &&
+ (recurrent_to_input_weights == nullptr));
+ TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+ TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ if (cell_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ if (cell_to_forget_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+ if (cell_to_output_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+ }
+
+ // Making sure the peephole weights are there all or none.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool peephole_weights_all_or_none =
+ ((cell_to_input_weights != nullptr || use_cifg) &&
+ (cell_to_forget_weights != nullptr) &&
+ (cell_to_output_weights != nullptr)) ||
+ ((cell_to_input_weights == nullptr) &&
+ (cell_to_forget_weights == nullptr) &&
+ (cell_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+ // Make sure the input gate bias is present only when not a CIFG-LSTM.
+ TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ if (use_cifg) {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+ } else {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+ }
+
+ TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+ TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ if (projection_weights) {
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+ }
+
+ TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+ if (projection_bias) {
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+ }
+
+ // Making sure the projection tensors are consistent:
+ // 1) If projection weight is not present, then projection bias should not be
+ // present.
+ // 2) If projection weight is present, then projection bias is optional.
+ // TODO(ghodrat): make sure this is correct.
+ const bool projecton_tensors_consistent =
+ ((projection_weights != nullptr) || (projection_bias == nullptr));
+ TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
+
+ return kTfLiteOk;
+}
+
+// Resize the output, state and scratch tensors based on the sizes of the input
+// tensors. Also check that the size of the input tensors match each other.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+
+ // Inferring batch size, number of outputs and sequence length and
+ // number of cells from the input tensors.
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, input->dims->size > 1);
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+
+ TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+ const int n_cell = input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+ TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+ n_cell);
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Check that input tensor dimensions matches with each other.
+ CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+
+ // Get the pointer to output, state and scratch buffer tensors.
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ // TODO(ghodrat): Modify this as soon as we have a finalized method for
+ // scratch buffers.
+ TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+
+ // Resize the output and output_state tensors.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
+ output_size->data[0] = max_time;
+ output_size->data[1] = n_batch;
+ output_size->data[2] = n_output;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+
+ TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
+ output_state_size->data[0] = n_batch;
+ output_state_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, output_state, output_state_size));
+
+ // Resize the scratch buffer tensor.
+ TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
+ cell_size->data[0] = n_batch;
+ cell_size->data[1] = n_cell;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state, cell_size));
+
+ // Mark state tensors as persistent tensors.
+ output_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ if (use_cifg) {
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ // Reserving space for Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 3;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+ } else {
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ // Reserving space for Input, Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 4;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+ }
+ return kTfLiteOk;
+}
+
+// The LSTM Op engine.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr_time = input->data.f + t * n_batch * n_input;
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell,
+ n_batch, input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell,
+ n_batch, output_gate_scratch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights->data.f, n_cell, n_input, input_ptr_time,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights->data.f, n_cell, n_input, input_ptr_time,
+ n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights->data.f, n_cell, n_input, input_ptr_time, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights->data.f, n_cell, n_input, input_ptr_time,
+ n_batch, output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, input_gate_scratch,
+ /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights->data.f, n_cell, n_output,
+ output_state->data.f, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
+ cell_state->data.f, n_batch * n_cell,
+ cell_state->data.f);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell,
+ cell_state->data.f);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell,
+ cell_state->data.f);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell,
+ params->cell_clip, cell_state->data.f);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell,
+ output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights != nullptr);
+ const bool use_projection_bias = (projection_bias != nullptr);
+ float* output_ptr_time = output->data.f + t * n_batch * n_output;
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output,
+ n_batch, output_ptr_time);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights->data.f, n_output, n_cell, output_gate_scratch,
+ n_batch, output_ptr_time, /*result_stride=*/1);
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_time, n_batch * n_output,
+ params->proj_clip, output_ptr_time);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_time);
+ }
+ tensor_utils::CopyVector(output_ptr_time, n_batch * n_output,
+ output_state->data.f);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace unidirectional_sequence_lstm
+
+TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ unidirectional_sequence_lstm::Prepare,
+ unidirectional_sequence_lstm::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
new file mode 100644
index 0000000000..93b635ae57
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -0,0 +1,1089 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Sequential LSTM op.
+
+#include <iomanip>
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class UnidirectionalLSTMOpModel : public SingleOpModel {
+ public:
+ UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ int sequence_length, bool use_cifg,
+ bool use_peephole, bool use_projection_weights,
+ bool use_projection_bias, float cell_clip,
+ float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output),
+ sequence_length_(sequence_length) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ }
+ cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(TensorType_FLOAT32);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ scratch_buffer_ = AddOutput(TensorType_FLOAT32);
+ // TODO(ghodrat): Modify these states when we have a permanent solution for
+ // persistent buffer.
+ output_state_ = AddOutput(TensorType_FLOAT32);
+ cell_state_ = AddOutput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOptions_LSTMOptions,
+ CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
+ cell_clip, proj_clip)
+ .Union());
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void ResetOutputState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(output_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void ResetCellState() {
+ const int zero_buffer_size = n_cell_ * n_batch_;
+ std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
+ memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+ PopulateTensor(cell_state_, 0, zero_buffer.get(),
+ zero_buffer.get() + zero_buffer_size);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+ int sequence_length() { return sequence_length_; }
+
+ private:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_;
+ int output_state_;
+ int cell_state_;
+ int scratch_buffer_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+ int sequence_length_;
+};
+
+TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ UnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // Input should have n_input * sequence_length many values.
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
+ -0.15358765, -0.03716109, 0.12507336,
+ 0.41193449, -0.20860538, -0.15053082,
+ 0.09120187, 0.24278517, -0.12222792};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ float* batch0_start = lstm_input;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ float* golden_start = lstm_golden_output;
+ float* golden_end =
+ golden_start + lstm.num_outputs() * lstm.sequence_length();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+}
+
+TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ UnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
+ /*use_peephole=*/true, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
+ -0.05163646, -0.42312205, -0.01218222,
+ 0.24201041, -0.08124574, -0.358325,
+ -0.04621704, 0.21641694, -0.06471302};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ float* batch0_start = lstm_input;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ float* golden_start = lstm_golden_output;
+ float* golden_end =
+ golden_start + lstm.num_outputs() * lstm.sequence_length();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+}
+
+TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
+ const int sequence_length = 4;
+
+ UnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/true, /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(
+ {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
+ 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
+ -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
+ -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
+ -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
+ -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
+ 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
+ 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
+ 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
+ -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
+ 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
+ -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
+ -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
+ 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
+ -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
+ -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
+ -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
+
+ lstm.SetInputToForgetWeights(
+ {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
+ -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
+ -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
+ 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
+ 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
+ -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
+ -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
+ 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
+ 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
+ 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
+ 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
+ -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
+ -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
+ -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
+ 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
+ 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
+ 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
+ -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
+ 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
+
+ lstm.SetInputToCellWeights(
+ {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
+ -0.043528453, 0.043018587, -0.049152344, -0.12418144,
+ -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
+ -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
+ 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239,
+ 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
+ -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339,
+ -0.025174323, 0.0396852, 0.081777506, 0.06157468,
+ 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
+ 0.053568836, 0.06408714, 0.12835667, -0.008714329,
+ -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786,
+ -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
+ -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112,
+ -0.021742892, -0.023377212, -0.07221364, -0.06430552,
+ 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911,
+ 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
+ 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
+
+ lstm.SetInputToOutputWeights(
+ {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
+ -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
+ 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
+ -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
+ -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
+ 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
+ -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
+ -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
+ -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
+ 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
+ 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
+ -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
+ 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
+ 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
+ -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
+ 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
+ -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
+
+ lstm.SetInputGateBias(
+ {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
+ -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
+ -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
+ 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
+
+ lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
+ 0.11098921, 0.15378423, 0.09263801, 0.09790885,
+ 0.09508917, 0.061199076, 0.07665568, -0.015443159,
+ -0.03499149, 0.046190713, 0.08895977, 0.10899629,
+ 0.40694186, 0.06030037, 0.012413437, -0.06108739});
+
+ lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
+ -0.1483596, -0.10639995, -0.091433935, 0.058573797,
+ -0.06809782, -0.07889636, -0.043246906, -0.09829136,
+ -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
+ 0.016178843, 0.1749513, 0.13975595, 0.92058027});
+
+ lstm.SetOutputGateBias(
+ {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
+ 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
+ 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
+ -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
+ -0.11585556, 0.02557986, -0.13446963, -0.035785314,
+ -0.01244275, 0.025961924, -0.02337298, -0.044228926,
+ -0.055839065, -0.046598054, -0.010546039, -0.06900766,
+ 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011,
+ -0.14390695, -0.02916037, 0.000996957, 0.091420636,
+ 0.14283475, -0.07390571, -0.06402044, 0.062524505,
+ -0.093129106, 0.04860203, -0.08364217, -0.08119002,
+ 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048,
+ 0.12162708, -0.031923793, -0.014335606, 0.01790974,
+ -0.10650317, -0.0724401, 0.08554849, -0.05727212,
+ 0.06556731, -0.042729504, -0.043227166, 0.011683251,
+ -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106,
+ -0.07787477, -0.11576462, 0.017356863, 0.048673786,
+ -0.017577527, -0.05527947, -0.082487635, -0.040137455,
+ -0.10820036, -0.04666372, 0.022746278, -0.07851417,
+ 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705,
+ 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
+ 0.040414046, -0.1380399, 0.094208956, -0.05722982,
+ 0.012092817, -0.04989123, -0.086576, -0.003399834,
+ -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
+ 0.09504992, 0.041799378, -0.049185462, -0.031518843,
+ -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
+ -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
+ -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335,
+ -0.011337001, 0.035530265, -0.010912711, 0.0706555,
+ -0.005894094, 0.051841937, -0.1401738, -0.02351249,
+ 0.0365468, 0.07590991, 0.08838724, 0.021681072,
+ -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986,
+ -0.048691444, -0.009579111, 0.07595467, 0.11480546,
+ -0.09801813, 0.019894179, 0.08502348, 0.004032281,
+ 0.037211012, 0.068537936, -0.048005626, -0.091520436,
+ -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261,
+ 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
+ -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
+ 0.015963363, 0.00871737, 0.060130805, 0.028611384,
+ 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
+ 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
+ 0.019899689, 0.006106124, -0.027092824, 0.0786356,
+ 0.05052217, -0.058925, -0.011402121, -0.024987547,
+ -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
+ -0.033664223, -0.07978348, -0.025200296, -0.017207067,
+ -0.058403496, -0.055697463, 0.005798788, 0.12965427,
+ -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
+ 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066,
+ -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
+ 0.062634714, -0.02338735, -0.039547626, -0.02050681,
+ 0.03385117, -0.083611414, 0.002862572, -0.09421313,
+ 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887,
+ -0.07314807, -0.056307215, -0.10433547, -0.06440842,
+ 0.04328182, 0.04389765, -0.020006588, -0.09076438,
+ -0.11652589, -0.021705797, 0.03345259, -0.010329105,
+ -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465,
+ 0.1305557, 0.058638252, -0.03393652, 0.09622831,
+ -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
+ -0.005644518, 0.06857898, -0.12598175, -0.035084512,
+ 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
+ 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
+ 0.08184801, -0.019164372, 0.06791302, 0.034257166,
+ -0.10307039, 0.021943003, 0.046745934, 0.0790918,
+ -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
+ -0.014512694, -0.08251313, 0.08861942, 0.13589665,
+ 0.026351685, 0.012641483, 0.07466548, 0.044301085,
+ -0.045414884, -0.051112458, 0.03444247, -0.08502782,
+ -0.04106223, -0.028126027, 0.028473156, 0.10467447});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
+ 0.14811787, 0.10826372, 0.09471067, 0.03987225,
+ -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
+ 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
+ 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
+ -0.06193199, 0.055729095, 0.03736828, 0.020123724,
+ 0.061878487, -0.04729229, 0.034919553, -0.07585433,
+ -0.04421272, -0.044019096, 0.085488975, 0.04058006,
+ -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368,
+ -0.04899627, 0.0241671, 0.015736353, -0.095442444,
+ -0.029564252, 0.016493602, -0.035026584, 0.022337519,
+ -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
+ 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644,
+ -0.046250615, -0.01847454, 0.047608484, 0.07339695,
+ 0.034546845, -0.04881143, 0.009128804, -0.08802852,
+ 0.03761666, 0.008096139, -0.014454086, 0.014361001,
+ -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793,
+ 0.060212336, 0.055259194, 0.06974018, 0.049454916,
+ -0.027794661, -0.08077226, -0.016179763, 0.1169753,
+ 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
+ -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
+ -0.05695512, 0.047233116, 0.038937137, -0.06542224,
+ 0.014429736, -0.09719407, 0.13908425, -0.05379757,
+ 0.012321099, 0.082840554, -0.029899208, 0.044217527,
+ 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
+ -0.13873616, 0.040668588, 0.034832682, -0.015319203,
+ -0.018715994, 0.046002675, 0.0599172, -0.043107376,
+ 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
+ 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
+ 0.052958444, 0.07558703, 0.04817258, 0.044462286,
+ -0.015213451, -0.08783778, -0.0561384, -0.003008196,
+ 0.047060397, -0.002058388, 0.03429439, -0.018839769,
+ 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
+ -0.02558259, -0.022822596, -0.023273505, -0.02464396,
+ -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
+ 0.04383914, -0.046476185, 0.028658995, 0.060410924,
+ 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
+ 0.015898481, 0.021362653, -0.030262267, 0.016587038,
+ -0.011442813, 0.041154444, -0.007631438, -0.03423484,
+ -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
+ 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105,
+ -0.03672871, 0.024019798, 0.014255957, -0.05221243,
+ -0.00661567, -0.04630967, 0.033188973, 0.10107534,
+ -0.014027541, 0.030796422, -0.10270911, -0.035999842,
+ 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
+ -0.03858649, 0.01849943, 0.13872518, 0.01503974,
+ 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
+ -0.047401894, 0.03100163, -0.041533746, -0.10430945,
+ 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947,
+ 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
+ 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
+ 0.03532124, -0.016341697, 0.09685456, -0.016764693,
+ 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895,
+ 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
+ -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
+ -0.01769146, 0.040995963, 0.02235177, -0.060430344,
+ 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733,
+ -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
+ 0.04096551, 0.032249358, -0.08355519, -0.026823482,
+ 0.056386515, -0.010401743, -0.028396193, 0.08507674,
+ 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
+ -0.081302024, 0.017264642, -0.009585969, 0.09491168,
+ -0.051313367, 0.054532815, -0.014298593, 0.10657464,
+ 0.007076659, 0.10964551, 0.0409152, 0.008275321,
+ -0.07283536, 0.07937492, 0.04192024, -0.1075027});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
+ 0.055647098, -0.05713207, -0.05626563, 0.005559383,
+ 0.03375411, -0.025757805, -0.088049285, 0.06017052,
+ -0.06570978, 0.007384076, 0.035123326, -0.07920549,
+ 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287,
+ -0.027673481, 0.044746667, 0.028349208, 0.020090483,
+ -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
+ -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
+ -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149,
+ -0.046195522, 0.0422062, -0.005683705, -0.1253618,
+ -0.012925729, -0.04890792, 0.06985068, 0.037654128,
+ 0.03398274, -0.004781977, 0.007032333, -0.031787455,
+ 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
+ -0.048885044, -0.12722108, 0.035304096, 0.06554885,
+ 0.00972396, -0.039238118, -0.05159735, -0.11329045,
+ 0.1613692, -0.03750952, 0.06529313, -0.071974665,
+ -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344,
+ 0.015983658, 0.03406988, -0.06939408, 0.040699873,
+ 0.02111075, 0.09669095, 0.041345075, -0.08316494,
+ -0.07684199, -0.045768797, 0.032298047, -0.041805092,
+ 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324,
+ 0.06760663, -0.027437469, 0.07216407, 0.06977076,
+ -0.05438599, 0.034033038, -0.028602652, 0.05346137,
+ 0.043184172, -0.037189785, 0.10420091, 0.00882477,
+ -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
+ 0.04361412, -0.007001822, 0.09631092, -0.06702025,
+ -0.042049985, -0.035070654, -0.04103342, -0.10273396,
+ 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
+ -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
+ -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
+ -0.04043371, -0.017094059, 0.07229206, -0.023670016,
+ -0.052195564, -0.025616996, -0.01520939, 0.045104615,
+ -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856,
+ 0.11868599, 0.07663395, -0.07323171, 0.03463402,
+ -0.050708205, -0.04458982, -0.11590894, 0.021273347,
+ 0.1251325, -0.15313013, -0.12224372, 0.17228661,
+ 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414,
+ -0.006041347, -0.09292539, -0.05678812, 0.03897832,
+ 0.09465633, 0.008115513, -0.02171956, 0.08304309,
+ 0.071401566, 0.019622514, 0.032163795, -0.004167056,
+ 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
+ -0.1335546, -0.030136576, 0.11584653, -0.014678886,
+ 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
+ -0.0329582, 0.07922767, 0.029322514, 0.026405897,
+ 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226,
+ -0.09606514, -0.006669445, -0.043381985, 0.04240257,
+ -0.06955775, -0.06769346, 0.043903265, -0.026784198,
+ -0.017840602, 0.024307009, -0.040079936, -0.019946516,
+ 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589,
+ -0.039390966, -0.072174534, 0.0739445, -0.1211869,
+ -0.0347889, -0.07943156, 0.014809798, -0.12412325,
+ -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
+ -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386,
+ 0.012962922, -0.031234352, 0.07029052, 0.016418684,
+ 0.03618972, 0.055686004, -0.08663945, -0.017404709,
+ -0.054761406, 0.029065743, 0.052404847, 0.020238016,
+ 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778,
+ -0.0270329, -0.03299152, -0.060088247, -0.015162964,
+ -0.001828936, 0.12642565, -0.056757294, 0.013586685,
+ 0.09232601, -0.035886683, 0.06000002, 0.05229691,
+ -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139,
+ 0.031502828, 0.036232427, -0.031581745, 0.023051167,
+ -0.05325106, -0.03421577, 0.028793324, -0.034633752,
+ -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
+ -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
+
+ lstm.SetRecurrentToOutputWeights({
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
+ -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
+ -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
+ -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
+ -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
+ -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
+ 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
+ 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
+ -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
+ -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
+ 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
+ -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
+ 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
+ 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
+ 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
+ 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
+ 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
+ -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
+ 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
+ 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
+ -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
+ -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
+ -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
+ -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
+ -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
+ 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
+ -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
+ 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
+ -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
+ -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
+ 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
+ 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
+ -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
+ 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
+ -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
+ -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
+ -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
+ -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
+ 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
+ -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
+ -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
+ -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
+ 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
+ -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
+ 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
+ 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
+ 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
+ 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
+ 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ });
+
+ lstm.SetCellToInputWeights(
+ {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
+
+ lstm.SetCellToForgetWeights(
+ {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
+ -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
+ -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
+ 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
+
+ lstm.SetCellToOutputWeights(
+ {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
+
+ lstm.SetProjectionWeights(
+ {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
+ 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
+ -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
+ -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
+ 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
+ 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
+ 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
+ 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
+ -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
+ -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
+ -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
+ 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
+ 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
+ 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
+ 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
+ -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
+ 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
+ -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
+ 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
+ -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
+ -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
+ 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
+ -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
+ -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
+ -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
+ 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
+ -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
+ -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
+ -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
+ 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
+ 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
+ -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
+ 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
+ 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
+ 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
+ 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
+ -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
+ -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
+ 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
+ -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
+ -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
+ 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
+ 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
+ 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
+ -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
+ -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
+ 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
+ -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
+ 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
+ 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
+ -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
+ -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
+ -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
+ 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
+ -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
+ -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
+ 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
+ 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
+ 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
+
+ static float lstm_input[][20] = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
+ 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
+ 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
+ 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
+ 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
+
+ static float lstm_golden_output[][64] = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
+ -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
+ -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
+ 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
+ -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
+ -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
+ 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
+ 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
+ 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
+ -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
+ -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
+ -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
+ 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
+ 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
+ -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
+ -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
+ 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
+ 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
+ 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
+ -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
+ -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ for (int i = 0; i < lstm.sequence_length(); i++) {
+ float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
+ float* batch0_end = batch0_start + lstm.num_inputs();
+
+ lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end);
+
+ float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
+ float* batch1_end = batch1_start + lstm.num_inputs();
+ lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end);
+ }
+
+ lstm.Invoke();
+
+ std::vector<float> expected;
+ for (int i = 0; i < lstm.sequence_length(); i++) {
+ float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
+ float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
+ float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
+ float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
+ expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
+ expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
+ }
+ EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 85e09049ee..f5f1ec2cf3 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -34,7 +34,7 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int KHiddenStateTensor = 0;
+constexpr int kHiddenStateTensor = 0;
constexpr int kOutputTensor = 1;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@@ -51,8 +51,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check all the parameters of tensor match within themselves and match the
// input configuration.
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int num_units = input_weights->dims->data[0];
TF_LITE_ASSERT_EQ(input->dims->data[2], input_weights->dims->data[1]);
TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
@@ -60,7 +64,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ &context->tensors[node->outputs->data[kHiddenStateTensor]];
TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
// Resize state.
@@ -75,8 +79,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
- output_size_array->data[0] = batch_size;
- output_size_array->data[1] = max_time;
+ output_size_array->data[0] = (time_major) ? max_time : batch_size;
+ output_size_array->data[1] = (time_major) ? batch_size : max_time;
output_size_array->data[2] = num_units;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output,
output_size_array));
@@ -84,8 +88,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+namespace {
+void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int num_units, int input_weights_stride,
+ int recurrent_weights_stride, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ // Output = bias
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] = bias_ptr[o];
+ }
+
+ // Output += input * input_weights
+ for (int o = 0; o < num_units; o++) {
+ for (int i = 0; i < input_size; i++) {
+ output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
+ }
+ input_weights_ptr += input_weights_stride;
+ }
+
+ // Output += recurrent_weights * hidden_state
+ for (int o = 0; o < num_units; o++) {
+ for (int h = 0; h < num_units; h++) {
+ output_ptr_batch[o] +=
+ hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
+ }
+ recurrent_weights_ptr += recurrent_weights_stride;
+ }
+
+ // Output = activation(Output) and update hidden_state
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]);
+ hidden_state_ptr_batch[o] = output_ptr_batch[o];
+ }
+}
+} // namespace
+
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
+ auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
TfLiteTensor* input_weights =
@@ -94,61 +134,60 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
&context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ &context->tensors[node->outputs->data[kHiddenStateTensor]];
TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
// Initialize the pointer bias.
const float* bias_ptr = bias->data.f;
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[2];
const int input_weights_stride = input_weights->dims->data[1];
const int recurrent_weights_stride = recurrent_weights->dims->data[1];
- // For each batch
- for (int b = 0; b < batch_size; b++) {
- // Initialize the pointer to hidden state.
- float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
- for (int s = 0; s < max_time; s++) {
- // Initialize the pointer to input and output.
- const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
- float* output_ptr_batch =
- output->data.f + b * num_units * max_time + s * num_units;
-
- // Initialize input_weights and recurrent_weights.
- const float* input_weights_ptr = input_weights->data.f;
- const float* recurrent_weights_ptr = recurrent_weights->data.f;
-
- // Output = bias
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] = bias_ptr[o];
- }
+ // Initialize input_weights and recurrent_weights.
+ const float* input_weights_ptr = input_weights->data.f;
+ const float* recurrent_weights_ptr = recurrent_weights->data.f;
- // Output += input * input_weights
- for (int o = 0; o < num_units; o++) {
- for (int i = 0; i < input_size; i++) {
- output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
- }
- input_weights_ptr += input_weights_stride;
- }
-
- // Output += recurrent_weights * hidden_state
- for (int o = 0; o < num_units; o++) {
- for (int h = 0; h < num_units; h++) {
- output_ptr_batch[o] +=
- hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
- }
- recurrent_weights_ptr += recurrent_weights_stride;
+ if (time_major) {
+ // Unroll the sequence
+ for (int s = 0; s < max_time; s++) {
+ for (int b = 0; b < batch_size; b++) {
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size + b * input_size;
+ float* output_ptr_batch =
+ output->data.f + s * num_units * batch_size + b * num_units;
+
+ RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
+ bias_ptr, input_size, num_units, input_weights_stride,
+ recurrent_weights_stride, params->activation,
+ hidden_state_ptr_batch, output_ptr_batch);
}
-
- // Output = activation(Output) and update hidden_state
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] =
- (ActivationFunctor(params->activation))(output_ptr_batch[o]);
- hidden_state_ptr_batch[o] = output_ptr_batch[o];
+ }
+ } else {
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+ for (int s = 0; s < max_time; s++) {
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ float* output_ptr_batch =
+ output->data.f + b * num_units * max_time + s * num_units;
+
+ RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
+ bias_ptr, input_size, num_units, input_weights_stride,
+ recurrent_weights_stride, params->activation,
+ hidden_state_ptr_batch, output_ptr_batch);
}
}
}
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index a1c1eda160..82c680ec3d 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Unit test for TFLite RNN op.
+// Unit test for TFLite Sequential RNN op.
#include <vector>
#include <iomanip>
@@ -125,7 +125,8 @@ static float rnn_golden_output[] = {
class UnidirectionalRNNOpModel : public SingleOpModel {
public:
- UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size)
+ UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size,
+ bool time_major)
: batches_(batches),
sequence_len_(sequence_len),
units_(units),
@@ -136,13 +137,22 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
bias_ = AddInput(TensorType_FLOAT32);
hidden_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(
- BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_RNNOptions,
- CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
- BuildInterpreter({{batches_, sequence_len_, input_size_},
- {units_, input_size_},
- {units_, units_},
- {units_}});
+ SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
+ BuiltinOptions_SequenceRNNOptions,
+ CreateSequenceRNNOptions(builder_, time_major,
+ ActivationFunctionType_RELU)
+ .Union());
+ if (time_major) {
+ BuildInterpreter({{sequence_len_, batches_, input_size_},
+ {units_, input_size_},
+ {units_, units_},
+ {units_}});
+ } else {
+ BuildInterpreter({{batches_, sequence_len_, input_size_},
+ {units_, input_size_},
+ {units_, units_},
+ {units_}});
+ }
}
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -195,7 +205,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
// TODO(mirkov): add another test which directly compares to TF once TOCO
// supports the conversion from dynamic_rnn with BasicRNNCell.
TEST(FullyConnectedOpTest, BlackBoxTest) {
- UnidirectionalRNNOpModel rnn(2, 16, 16, 8);
+ UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8, /*time_major=*/false);
rnn.SetWeights(
{0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
@@ -260,6 +271,77 @@ TEST(FullyConnectedOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
+TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) {
+ UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8, /*time_major=*/true);
+ rnn.SetWeights(
+ {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
+ 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
+ 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
+ -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
+ -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
+ -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
+ -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
+ 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
+ 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
+ -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
+ 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
+ -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
+ -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
+ 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
+ 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
+ -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
+ 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
+ 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
+ -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818});
+
+ rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
+ -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
+ 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
+ -0.37609905});
+
+ rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1});
+
+ rnn.ResetHiddenState();
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ // The two batches are identical.
+ rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
+ rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
+ }
+
+ rnn.Invoke();
+
+ std::vector<float> expected;
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* golden_batch_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_batch_end = golden_batch_start + rnn.num_units();
+ expected.insert(expected.end(), golden_batch_start, golden_batch_end);
+ expected.insert(expected.end(), golden_batch_start, golden_batch_end);
+ }
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h
index b11d86c375..5cd6c20850 100644
--- a/tensorflow/contrib/lite/memory_planner.h
+++ b/tensorflow/contrib/lite/memory_planner.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
+#define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
#include "tensorflow/contrib/lite/context.h"
@@ -42,4 +42,4 @@ class MemoryPlanner {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 86e613736d..415d984ad8 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -30,17 +30,6 @@ limitations under the License.
namespace tflite {
-namespace {
-inline const tflite::Model* VerifyAndGetModel(const void* buf, size_t len) {
- ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
- if (VerifyModelBuffer(verifier)) {
- return ::tflite::GetModel(buf);
- } else {
- return nullptr;
- }
-}
-} // namespace
-
const char* kEmptyTensorName = "";
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
@@ -82,7 +71,7 @@ FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
}
if (!allocation_->valid() || !CheckModelIdentifier()) return;
- model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes());
+ model_ = ::tflite::GetModel(allocation_->base());
}
bool FlatBufferModel::CheckModelIdentifier() const {
@@ -103,7 +92,7 @@ FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
if (!allocation_->valid()) return;
- model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes());
+ model_ = ::tflite::GetModel(allocation_->base());
}
FlatBufferModel::FlatBufferModel(const Model* model,
@@ -339,7 +328,17 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
+ TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ if (auto* sequence_rnn_params =
+ op->builtin_options_as_SequenceRNNOptions()) {
+ params->activation =
+ parse_activation(sequence_rnn_params->fused_activation_function());
+ params->time_major = sequence_rnn_params->time_major();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_RNN: {
TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
@@ -453,6 +452,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
@@ -468,32 +468,11 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
auto* params = MallocPOD<TfLiteResizeBilinearParams>();
if (auto* schema_params =
op->builtin_options_as_ResizeBilinearOptions()) {
- params->new_height = schema_params->new_height();
- params->new_width = schema_params->new_width();
}
builtin_data = reinterpret_cast<void*>(params);
break;
}
case BuiltinOperator_PAD: {
- auto* params = MallocPOD<TfLitePadParams>();
- if (auto* schema_params = op->builtin_options_as_PadOptions()) {
- auto* before_padding = schema_params->before_padding();
- FlatBufferIntVectorToArray(sizeof(params->before_padding),
- before_padding, params->before_padding,
- error_reporter);
-
- auto* after_padding = schema_params->after_padding();
- FlatBufferIntVectorToArray(sizeof(params->after_padding), after_padding,
- params->after_padding, error_reporter);
-
- if (before_padding->Length() != after_padding->Length()) {
- error_reporter->Report(
- "Before padding and after padding arrays need to contain the "
- "same number of dimensions.\n");
- }
- params->num_dimensions = after_padding->Length();
- }
- builtin_data = reinterpret_cast<void*>(params);
break;
}
case BuiltinOperator_RESHAPE: {
@@ -607,6 +586,18 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_STRIDED_SLICE: {
+ auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
+ params->begin_mask = schema_params->begin_mask();
+ params->end_mask = schema_params->end_mask();
+ params->ellipsis_mask = schema_params->ellipsis_mask();
+ params->new_axis_mask = schema_params->new_axis_mask();
+ params->shrink_axis_mask = schema_params->shrink_axis_mask();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
}
return builtin_data;
}
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index e0c96f7f04..a467df5bb4 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -31,8 +31,8 @@ limitations under the License.
// OpResolver must be defined to provide your kernel implementations to the
// interpreter. This is environment specific and may consist of just the builtin
// ops, or some custom operators you defined to extend tflite.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_
+#define TENSORFLOW_CONTRIB_LITE_MODEL_H_
#include <memory>
#include "tensorflow/contrib/lite/error_reporter.h"
@@ -173,4 +173,4 @@ class InterpreterBuilder {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_MODEL_H_
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index 5330c8f594..66f22fd66a 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
-#include <string>
#include "tensorflow/contrib/lite/model.h"
@@ -247,14 +246,6 @@ TEST(BasicFlatBufferModel, TestNullErrorReporter) {
ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
}
-// Test what happens if we cannot bind any of the ops.
-TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) {
- std::string corrupted_data = "123";
- auto model = FlatBufferModel::BuildFromBuffer(corrupted_data.c_str(),
- corrupted_data.length());
- ASSERT_FALSE(model);
-}
-
// Test that loading model directly from a Model flatbuffer works.
TEST(BasicFlatBufferModel, TestBuildFromModel) {
TestErrorReporter reporter;
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h
index d17323a3f9..90260c8d62 100644
--- a/tensorflow/contrib/lite/models/smartreply/predictor.h
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+#define TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
#include <string>
#include <vector>
@@ -77,4 +77,4 @@ struct SmartReplyConfig {
} // namespace custom
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
index 97d3c650e2..e6c8d966f1 100644
--- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
+++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
@@ -22,8 +22,9 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
-#include "tensorflow/contrib/lite/models/test_utils.h"
+//#include "tensorflow/contrib/lite/models/test_utils.h"
#include "tensorflow/contrib/lite/string_util.h"
+#include "tensorflow/core/platform/test.h"
namespace tflite {
namespace custom {
@@ -33,6 +34,11 @@ namespace {
const char kModelName[] = "smartreply_ondevice_model.bin";
const char kSamples[] = "smartreply_samples.tsv";
+string TestDataPath() {
+ return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/",
+ "contrib/lite/models/testdata/"));
+}
+
MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") {
bool has_expected_response = false;
for (const auto &item : *arg) {
diff --git a/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc b/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc
deleted file mode 100644
index bf95b313f3..0000000000
--- a/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc
+++ /dev/null
@@ -1,127 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// Unit test for speech ASR AM model using TFLite Ops.
-
-#include <string.h>
-
-#include <memory>
-#include <string>
-
-#include "base/logging.h"
-#include "file/base/path.h"
-#include "testing/base/public/googletest.h"
-#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/models/test_utils.h"
-
-namespace tflite {
-namespace models {
-
-constexpr int kModelInputTensor = 0;
-constexpr int kLstmLayer1OutputStateTensor = 19;
-constexpr int kLstmLayer1CellStateTensor = 20;
-constexpr int kLstmLayer2OutputStateTensor = 40;
-constexpr int kLstmLayer2CellStateTensor = 41;
-constexpr int kLstmLayer3OutputStateTensor = 61;
-constexpr int kLstmLayer3CellStateTensor = 62;
-constexpr int kLstmLayer4OutputStateTensor = 82;
-constexpr int kLstmLayer4CellStateTensor = 83;
-constexpr int kLstmLayer5OutputStateTensor = 103;
-constexpr int kLstmLayer5CellStateTensor = 104;
-constexpr int kModelOutputTensor = 109;
-
-TEST(SpeechAsrAm, RandomIOTest) {
- // Read the model.
- string tflite_file_path =
- file::JoinPath(TestDataPath(), "speech_asr_am_model.tflite");
- auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
- CHECK(model) << "Failed to mmap model " << tflite_file_path;
-
- // Initialize the interpreter.
- ops::builtin::BuiltinOpResolver builtins;
- std::unique_ptr<Interpreter> interpreter;
- InterpreterBuilder(*model, builtins)(&interpreter);
- CHECK(interpreter != nullptr);
- interpreter->AllocateTensors();
-
- // Load the input frames.
- Frames input_frames;
- const string input_file_path =
- file::JoinPath(TestDataPath(), "speech_asr_am_model_in.csv");
- ReadFrames(input_file_path, &input_frames);
-
- // Load the golden output results.
- Frames output_frames;
- const string output_file_path =
- file::JoinPath(TestDataPath(), "speech_asr_am_model_out.csv");
- ReadFrames(output_file_path, &output_frames);
-
- const int speech_batch_size =
- interpreter->tensor(kModelInputTensor)->dims->data[0];
- const int speech_input_size =
- interpreter->tensor(kModelInputTensor)->dims->data[1];
- const int speech_output_size =
- interpreter->tensor(kModelOutputTensor)->dims->data[1];
-
- float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
- float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
-
- // Clear the LSTM state for layers.
- memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer4OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer4OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer4CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer4CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer5OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer5OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer5CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer5CellStateTensor)->bytes);
-
-
- for (int i = 0; i < input_frames.size(); i++) {
- // Feed the input to model.
- int frame_ptr = 0;
- for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
- input_ptr[k] = input_frames[i][frame_ptr++];
- }
- // Run the model.
- interpreter->Invoke();
- // Validate the output.
- for (int k = 0; k < speech_output_size; k++) {
- ASSERT_NEAR(output_ptr[k], output_frames[i][k], 5.2e-4);
- }
- }
-}
-
-} // namespace models
-} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc b/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc
deleted file mode 100644
index 53f2b66da4..0000000000
--- a/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// Unit test for speech ASR LM model using TFLite Ops.
-
-#include <string.h>
-
-#include <memory>
-#include <string>
-
-#include "base/logging.h"
-#include "file/base/path.h"
-#include "testing/base/public/googletest.h"
-#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/models/test_utils.h"
-
-namespace tflite {
-namespace models {
-
-constexpr int kModelInput1Tensor = 0;
-constexpr int kModelInput2Tensor = 66;
-constexpr int kLstmLayer1OutputStateTensor = 21;
-constexpr int kLstmLayer1CellStateTensor = 22;
-constexpr int kLstmLayer2OutputStateTensor = 42;
-constexpr int kLstmLayer2CellStateTensor = 43;
-constexpr int kLstmLayer3OutputStateTensor = 63;
-constexpr int kLstmLayer3CellStateTensor = 64;
-constexpr int kModelOutputTensor = 75;
-
-static void ClearLstmStates(Interpreter* interpreter) {
- memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
-}
-
-TEST(SpeechAsrLm, EndToEndTest) {
- // Read the model.
- string tflite_file_path =
- file::JoinPath(TestDataPath(), "speech_asr_lm_model.tflite");
- auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
- CHECK(model) << "Failed to mmap model " << tflite_file_path;
-
- // Initialize the interpreter.
- ops::builtin::BuiltinOpResolver builtins;
- std::unique_ptr<Interpreter> interpreter;
- InterpreterBuilder(*model, builtins)(&interpreter);
- CHECK(interpreter != nullptr);
- interpreter->AllocateTensors();
-
- // Load the input frames.
- Frames input_frames;
- const string input_file_path =
- file::JoinPath(TestDataPath(), "speech_asr_lm_model_in.csv");
- ReadFrames(input_file_path, &input_frames);
-
- // Load the golden output results.
- Frames output_frames;
- const string output_file_path =
- file::JoinPath(TestDataPath(), "speech_asr_lm_model_out.csv");
- ReadFrames(output_file_path, &output_frames);
-
- CHECK_EQ(interpreter->tensor(kModelInput1Tensor)->dims->size, 1);
- const int input1_size =
- interpreter->tensor(kModelInput1Tensor)->dims->data[0];
- CHECK_EQ(input1_size, 1);
- CHECK_EQ(interpreter->tensor(kModelInput2Tensor)->dims->size, 1);
- const int output_size =
- interpreter->tensor(kModelOutputTensor)->dims->data[0];
- CHECK_EQ(output_size, 1);
-
- int* input_lookup_ptr = interpreter->tensor(kModelInput1Tensor)->data.i32;
- int* output_lookup_ptr = interpreter->tensor(kModelInput2Tensor)->data.i32;
- float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
-
-
- for (int i = 0; i < input_frames.size(); i++) {
- float output_score = 0.0f;
- // Reset LSTM states for each sequence.
- ClearLstmStates(interpreter.get());
- // For subsequent inputs feed them sequentially, one-by-one.
- for (int k = 1; k < input_frames[i].size(); k++) {
- // Feed the inputs to model.
- input_lookup_ptr[0] = static_cast<int32>(input_frames[i][k - 1]);
- output_lookup_ptr[0] = static_cast<int32>(input_frames[i][k]);
- // Run the model.
- interpreter->Invoke();
- // Sum up the outputs.
- output_score += output_ptr[0];
- }
- // Validate the output.
- ASSERT_NEAR(output_score, output_frames[i][0], 1.4e-5);
- }
-}
-
-} // namespace models
-} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc b/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc
deleted file mode 100644
index f7e136113a..0000000000
--- a/tensorflow/contrib/lite/models/speech_endpointer_model_test.cc
+++ /dev/null
@@ -1,104 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// Unit test for speech EndPointer model using TFLite Ops.
-
-#include <string.h>
-
-#include <memory>
-#include <string>
-
-#include "base/logging.h"
-#include "testing/base/public/googletest.h"
-#include <gtest/gtest.h>
-#include "absl/strings/str_cat.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/models/test_utils.h"
-
-namespace tflite {
-namespace models {
-
-constexpr int kModelInputTensor = 0;
-constexpr int kLstmLayer1OutputStateTensor = 28;
-constexpr int kLstmLayer1CellStateTensor = 29;
-constexpr int kLstmLayer2OutputStateTensor = 49;
-constexpr int kLstmLayer2CellStateTensor = 50;
-constexpr int kModelOutputTensor = 58;
-
-TEST(SpeechEndpointer, EndpointerTest) {
- // Read the model.
- string tflite_file_path =
- StrCat(TestDataPath(), "/", "speech_endpointer_model.tflite");
- auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
- CHECK(model) << "Failed to read model from file " << tflite_file_path;
-
- // Initialize the interpreter.
- ops::builtin::BuiltinOpResolver builtins;
- std::unique_ptr<Interpreter> interpreter;
- InterpreterBuilder(*model, builtins)(&interpreter);
- CHECK(interpreter != nullptr);
- interpreter->AllocateTensors();
-
- // Load the input frames.
- Frames input_frames;
- const string input_file_path =
- StrCat(TestDataPath(), "/", "speech_endpointer_model_in.csv");
- ReadFrames(input_file_path, &input_frames);
-
- // Load the golden output results.
- Frames output_frames;
- const string output_file_path =
- StrCat(TestDataPath(), "/", "speech_endpointer_model_out.csv");
- ReadFrames(output_file_path, &output_frames);
-
- const int speech_batch_size =
- interpreter->tensor(kModelInputTensor)->dims->data[0];
- const int speech_input_size =
- interpreter->tensor(kModelInputTensor)->dims->data[1];
- const int speech_output_size =
- interpreter->tensor(kModelOutputTensor)->dims->data[1];
-
- float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
- float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
-
- // Clear the LSTM state for layers.
- memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
-
- for (int i = 0; i < input_frames.size(); i++) {
- // Feed the input to model.
- int frame_ptr = 0;
- for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
- input_ptr[k] = input_frames[i][frame_ptr++];
- }
- // Run the model.
- interpreter->Invoke();
- // Validate the output.
- for (int k = 0; k < speech_output_size; k++) {
- ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
- }
- }
-}
-
-} // namespace models
-} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
deleted file mode 100644
index f69cae8d2c..0000000000
--- a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc
+++ /dev/null
@@ -1,114 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// Unit test for speech Hotword model using TFLite Ops.
-
-#include <string.h>
-
-#include <memory>
-#include <string>
-
-#include "base/logging.h"
-#include "testing/base/public/googletest.h"
-#include <gtest/gtest.h>
-#include "absl/strings/str_cat.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/models/test_utils.h"
-
-namespace tflite {
-namespace models {
-
-void RunTest(int model_input_tensor, int svdf_layer_state_tensor,
- int model_output_tensor, const string& model_name,
- const string& golden_in_name, const string& golden_out_name) {
- // Read the model.
- string tflite_file_path = StrCat(TestDataPath(), "/", model_name);
- auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
- CHECK(model) << "Failed to read model from file " << tflite_file_path;
-
- // Initialize the interpreter.
- ops::builtin::BuiltinOpResolver builtins;
- std::unique_ptr<Interpreter> interpreter;
- InterpreterBuilder(*model, builtins)(&interpreter);
- CHECK(interpreter != nullptr);
- interpreter->AllocateTensors();
-
- // Reset the SVDF layer state.
- memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0,
- interpreter->tensor(svdf_layer_state_tensor)->bytes);
-
- // Load the input frames.
- Frames input_frames;
- const string input_file_path = StrCat(TestDataPath(), "/", golden_in_name);
- ReadFrames(input_file_path, &input_frames);
-
- // Load the golden output results.
- Frames output_frames;
- const string output_file_path = StrCat(TestDataPath(), "/", golden_out_name);
- ReadFrames(output_file_path, &output_frames);
-
- const int speech_batch_size =
- interpreter->tensor(model_input_tensor)->dims->data[0];
- const int speech_input_size =
- interpreter->tensor(model_input_tensor)->dims->data[1];
- const int speech_output_size =
- interpreter->tensor(model_output_tensor)->dims->data[1];
- const int input_sequence_size =
- input_frames[0].size() / (speech_input_size * speech_batch_size);
- float* input_ptr = interpreter->tensor(model_input_tensor)->data.f;
- float* output_ptr = interpreter->tensor(model_output_tensor)->data.f;
-
- // The first layer (SVDF) input size is 40 (speech_input_size). Each speech
- // input frames for this model is 1600 floats, which can be fed to input in a
- // sequence of size 40 (input_sequence_size).
- for (int i = 0; i < TestInputSize(input_frames); i++) {
- int frame_ptr = 0;
- for (int s = 0; s < input_sequence_size; s++) {
- for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
- input_ptr[k] = input_frames[i][frame_ptr++];
- }
- interpreter->Invoke();
- }
- // After the whole frame (1280 floats) is fed, we can check the output frame
- // matches with the golden output frame.
- for (int k = 0; k < speech_output_size; k++) {
- ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
- }
- }
-}
-
-TEST(SpeechHotword, OkGoogleTestRank1) {
- constexpr int kModelInputTensor = 0;
- constexpr int kSvdfLayerStateTensor = 4;
- constexpr int kModelOutputTensor = 18;
-
- RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor,
- "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
- "speech_hotword_model_out_rank1.csv");
-}
-
-TEST(SpeechHotword, OkGoogleTestRank2) {
- constexpr int kModelInputTensor = 17;
- constexpr int kSvdfLayerStateTensor = 1;
- constexpr int kModelOutputTensor = 18;
- RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor,
- "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
- "speech_hotword_model_out_rank2.csv");
-}
-
-} // namespace models
-} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc
deleted file mode 100644
index e208fac8df..0000000000
--- a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// Unit test for speech SpeakerId model using TFLite Ops.
-
-#include <string.h>
-
-#include <memory>
-#include <string>
-
-#include "base/logging.h"
-#include "testing/base/public/googletest.h"
-#include <gtest/gtest.h>
-#include "absl/strings/str_cat.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/models/test_utils.h"
-#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
-
-void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
-
-namespace tflite {
-namespace models {
-
-constexpr int kModelInputTensor = 0;
-constexpr int kLstmLayer1OutputStateTensor = 19;
-constexpr int kLstmLayer1CellStateTensor = 20;
-constexpr int kLstmLayer2OutputStateTensor = 40;
-constexpr int kLstmLayer2CellStateTensor = 41;
-constexpr int kLstmLayer3OutputStateTensor = 61;
-constexpr int kLstmLayer3CellStateTensor = 62;
-constexpr int kModelOutputTensor = 66;
-
-void SpeakerIdTest(bool useNNAPI) {
- // Read the model.
- string tflite_file_path =
- StrCat(TestDataPath(), "/", "speech_speakerid_model.tflite");
- auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
- CHECK(model) << "Failed to read model from file " << tflite_file_path;
-
- // Initialize the interpreter.
- ::tflite::MutableOpResolver resolver;
- RegisterSelectedOps(&resolver);
- std::unique_ptr<Interpreter> interpreter;
- InterpreterBuilder(*model, resolver)(&interpreter);
- CHECK(interpreter != nullptr);
-
- interpreter->UseNNAPI(useNNAPI);
-
- interpreter->AllocateTensors();
-
- // Load the input frames.
- Frames input_frames;
- const string input_file_path =
- StrCat(TestDataPath(), "/", "speech_speakerid_model_in.csv");
- ReadFrames(input_file_path, &input_frames);
-
- // Load the golden output results.
- Frames output_frames;
- const string output_file_path =
- StrCat(TestDataPath(), "/", "speech_speakerid_model_out.csv");
- ReadFrames(output_file_path, &output_frames);
-
- const int speech_batch_size =
- interpreter->tensor(kModelInputTensor)->dims->data[0];
- const int speech_input_size =
- interpreter->tensor(kModelInputTensor)->dims->data[1];
- const int speech_output_size =
- interpreter->tensor(kModelOutputTensor)->dims->data[1];
-
- float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
- float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
-
- // Clear the LSTM state for layers.
- memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
- for (int i = 0; i < input_frames.size(); i++) {
- // Feed the input to model.
- int frame_ptr = 0;
- for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
- input_ptr[k] = input_frames[i][frame_ptr++];
- }
- // Run the model.
- interpreter->Invoke();
- // Validate the output.
- for (int k = 0; k < speech_output_size; k++) {
- ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
- }
- }
-}
-
-TEST(SpeechSpeakerId, OkGoogleTest) { SpeakerIdTest(false); }
-
-TEST(SpeechSpeakerId, OkGoogleTestUsingNNAPI) { SpeakerIdTest(true); }
-
-} // namespace models
-} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc
new file mode 100644
index 0000000000..daa8c3100b
--- /dev/null
+++ b/tensorflow/contrib/lite/models/speech_test.cc
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for speech models (Hotword, SpeakerId) using TFLite Ops.
+
+#include <memory>
+#include <string>
+
+#include <fstream>
+
+#include "testing/base/public/googletest.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+#include "tensorflow/contrib/lite/testing/split.h"
+#include "tensorflow/contrib/lite/testing/tflite_driver.h"
+
+namespace tflite {
+namespace {
+
+const char kDataPath[] = "third_party/tensorflow/contrib/lite/models/testdata/";
+
+bool Init(const string& in_file_name, testing::TfLiteDriver* driver,
+ std::ifstream* in_file) {
+ driver->SetModelBaseDir(kDataPath);
+ in_file->open(string(kDataPath) + in_file_name, std::ifstream::in);
+ return in_file->is_open();
+}
+
+// Converts a set of test files provided by the speech team into a single
+// test_spec. Input CSV files are supposed to contain a number of sequences per
+// line. Each sequence maps to a single invocation of the interpreter and the
+// output tensor after all sequences have run is compared to the corresponding
+// line in the output CSV file.
+bool ConvertCsvData(const string& model_name, const string& in_name,
+ const string& out_name, const string& input_tensor,
+ const string& output_tensor,
+ const string& persistent_tensors, int sequence_size,
+ std::ostream* out) {
+ auto data_path = [](const string& s) { return string(kDataPath) + s; };
+
+ *out << "load_model: \"" << data_path(model_name) << "\"" << std::endl;
+
+ *out << "init_state: \"" << persistent_tensors << "\"" << std::endl;
+
+ string in_file_name = data_path(in_name);
+ std::ifstream in_file(in_file_name);
+ if (!in_file.is_open()) {
+ std::cerr << "Failed to open " << in_file_name << std::endl;
+ return false;
+ }
+ string out_file_name = data_path(out_name);
+ std::ifstream out_file(out_file_name);
+ if (!out_file.is_open()) {
+ std::cerr << "Failed to open " << out_file_name << std::endl;
+ return false;
+ }
+
+ int invocation_count = 0;
+ string in_values;
+ while (std::getline(in_file, in_values, '\n')) {
+ std::vector<string> input = testing::Split<string>(in_values, ",");
+ int num_sequences = input.size() / sequence_size;
+
+ for (int j = 0; j < num_sequences; ++j) {
+ *out << "invoke {" << std::endl;
+ *out << " id: " << invocation_count << std::endl;
+ *out << " input: \"";
+ for (int k = 0; k < sequence_size; ++k) {
+ *out << input[k + j * sequence_size] << ",";
+ }
+ *out << "\"" << std::endl;
+
+ if (j == num_sequences - 1) {
+ string out_values;
+ if (!std::getline(out_file, out_values, '\n')) {
+ std::cerr << "Not enough lines in " << out_file_name << std::endl;
+ return false;
+ }
+ *out << " output: \"" << out_values << "\"" << std::endl;
+ }
+
+ *out << "}" << std::endl;
+ ++invocation_count;
+ }
+ }
+ return true;
+}
+
+TEST(SpeechTest, HotwordOkGoogleRank1Test) {
+ std::stringstream os;
+ ASSERT_TRUE(ConvertCsvData(
+ "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank1.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"18", /*persistent_tensors=*/"4",
+ /*sequence_size=*/40, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+TEST(SpeechTest, HotwordOkGoogleRank2Test) {
+ std::stringstream os;
+ ASSERT_TRUE(ConvertCsvData(
+ "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
+ "speech_hotword_model_out_rank2.csv", /*input_tensor=*/"17",
+ /*output_tensor=*/"18", /*persistent_tensors=*/"1",
+ /*sequence_size=*/40, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+TEST(SpeechTest, SpeakerIdOkGoogleTest) {
+ std::stringstream os;
+ ASSERT_TRUE(ConvertCsvData(
+ "speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
+ "speech_speakerid_model_out.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"66",
+ /*persistent_tensors=*/"19,20,40,41,61,62",
+ /*sequence_size=*/80, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+TEST(SpeechTest, AsrAmTest) {
+ std::stringstream os;
+ ASSERT_TRUE(
+ ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
+ "speech_asr_am_model_out.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"109",
+ /*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104",
+ /*sequence_size=*/320, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+// The original version of speech_asr_lm_model_test.cc ran a few sequences
+// through the interpreter and stored the sum of all the output, which was them
+// compared for correctness. In this test we are comparing all the intermediate
+// results.
+TEST(SpeechTest, AsrLmTest) {
+ std::ifstream in_file;
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
+ ASSERT_TRUE(testing::ParseAndRunTests(&in_file, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+TEST(SpeechTest, EndpointerTest) {
+ std::stringstream os;
+ ASSERT_TRUE(ConvertCsvData(
+ "speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
+ "speech_endpointer_model_out.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"58",
+ /*persistent_tensors=*/"28,29,49,50",
+ /*sequence_size=*/320, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+TEST(SpeechTest, TtsTest) {
+ std::stringstream os;
+ ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
+ "speech_tts_model_in.csv",
+ "speech_tts_model_out.csv", /*input_tensor=*/"0",
+ /*output_tensor=*/"74",
+ /*persistent_tensors=*/"25,26,46,47,67,68,73",
+ /*sequence_size=*/334, &os));
+ testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
+ ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver))
+ << test_driver.GetErrorMessage();
+}
+
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/speech_tts_model_test.cc b/tensorflow/contrib/lite/models/speech_tts_model_test.cc
deleted file mode 100644
index 8829177689..0000000000
--- a/tensorflow/contrib/lite/models/speech_tts_model_test.cc
+++ /dev/null
@@ -1,116 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// Unit test for speech TTS model using TFLite Ops.
-
-#include <string.h>
-
-#include <memory>
-#include <string>
-
-#include "base/logging.h"
-#include "testing/base/public/googletest.h"
-#include <gtest/gtest.h>
-#include "absl/strings/str_cat.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/models/test_utils.h"
-
-namespace tflite {
-namespace models {
-
-constexpr int kModelInputTensor = 0;
-constexpr int kLstmLayer1OutputStateTensor = 25;
-constexpr int kLstmLayer1CellStateTensor = 26;
-constexpr int kLstmLayer2OutputStateTensor = 46;
-constexpr int kLstmLayer2CellStateTensor = 47;
-constexpr int kLstmLayer3OutputStateTensor = 67;
-constexpr int kLstmLayer3CellStateTensor = 68;
-constexpr int kRnnLayerHiddenStateTensor = 73;
-constexpr int kModelOutputTensor = 74;
-
-TEST(SpeechTTS, RandomIOTest) {
- // Read the model.
- string tflite_file_path =
- StrCat(TestDataPath(), "/", "speech_tts_model.tflite");
- auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
- CHECK(model) << "Failed to mmap model " << tflite_file_path;
-
- // Initialize the interpreter.
- ops::builtin::BuiltinOpResolver builtins;
- std::unique_ptr<Interpreter> interpreter;
- InterpreterBuilder(*model, builtins)(&interpreter);
- CHECK(interpreter != nullptr);
- interpreter->AllocateTensors();
-
- // Load the input frames.
- Frames input_frames;
- const string input_file_path =
- StrCat(TestDataPath(), "/", "speech_tts_model_in.csv");
- ReadFrames(input_file_path, &input_frames);
-
- // Load the golden output results.
- Frames output_frames;
- const string output_file_path =
- StrCat(TestDataPath(), "/", "speech_tts_model_out.csv");
- ReadFrames(output_file_path, &output_frames);
-
- const int speech_batch_size =
- interpreter->tensor(kModelInputTensor)->dims->data[0];
- const int speech_input_size =
- interpreter->tensor(kModelInputTensor)->dims->data[1];
- const int speech_output_size =
- interpreter->tensor(kModelOutputTensor)->dims->data[1];
-
- float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f;
- float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f;
-
- // Clear the LSTM state for layers.
- memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer1CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer2CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes);
- memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0,
- interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
-
- memset(interpreter->tensor(kRnnLayerHiddenStateTensor)->data.raw, 0,
- interpreter->tensor(kRnnLayerHiddenStateTensor)->bytes);
-
- for (int i = 0; i < input_frames.size(); i++) {
- // Feed the input to model.
- int frame_ptr = 0;
- for (int k = 0; k < speech_input_size * speech_batch_size; k++) {
- input_ptr[k] = input_frames[i][frame_ptr++];
- }
- // Run the model.
- interpreter->Invoke();
- // Validate the output.
- for (int k = 0; k < speech_output_size; k++) {
- ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5);
- }
- }
-}
-
-} // namespace models
-} // namespace tflite
diff --git a/tensorflow/contrib/lite/models/test_utils.h b/tensorflow/contrib/lite/models/test_utils.h
deleted file mode 100644
index 1e14c26a35..0000000000
--- a/tensorflow/contrib/lite/models/test_utils.h
+++ /dev/null
@@ -1,84 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
-
-#include <stdlib.h>
-#include <string.h>
-
-#include <fstream>
-#include <memory>
-#include <string>
-#include <vector>
-
-namespace tflite {
-namespace models {
-using Frames = std::vector<std::vector<float>>;
-} // namespace models
-} // namespace tflite
-
-#ifndef __ANDROID__
-#include "absl/strings/str_cat.h"
-#include "tensorflow/core/platform/test.h"
-
-inline string TestDataPath() {
- return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/",
- "contrib/lite/models/testdata/"));
-}
-inline int TestInputSize(const tflite::models::Frames& input_frames) {
- return input_frames.size();
-}
-#else
-inline string TestDataPath() {
- return string("third_party/tensorflow/contrib/lite/models/testdata/");
-}
-
-inline int TestInputSize(const tflite::models::Frames& input_frames) {
- // Android TAP is very slow, we only test the first 20 frames.
- return 20;
-}
-#endif
-
-namespace tflite {
-namespace models {
-
-// Read float data from a comma-separated file:
-// Each line will be read into a float vector.
-// The return result will be a vector of float vectors.
-void ReadFrames(const string& csv_file_path, Frames* frames) {
- std::ifstream csv_file(csv_file_path);
- string line;
- while (std::getline(csv_file, line, '\n')) {
- std::vector<float> fields;
- // Used by strtok_r internaly for successive calls on the same string.
- char* save_ptr = nullptr;
-
- // Tokenize the line.
- char* next_token =
- strtok_r(const_cast<char*>(line.c_str()), ",", &save_ptr);
- while (next_token != nullptr) {
- float f = strtod(next_token, nullptr);
- fields.push_back(f);
- next_token = strtok_r(nullptr, ",", &save_ptr);
- }
- frames->push_back(fields);
- }
- csv_file.close();
-}
-
-} // namespace models
-} // namespace tflite
-
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_
diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/contrib/lite/models/testdata/g3doc/README.md
index 667a588383..1c47e00aae 100644
--- a/tensorflow/contrib/lite/models/testdata/g3doc/README.md
+++ b/tensorflow/contrib/lite/models/testdata/g3doc/README.md
@@ -53,7 +53,7 @@ with the corresponding parameters as shown in the figure.
### Automatic Speech Recognizer (ASR) Acoustic Model (AM)
The acoustic model for automatic speech recognition is the neural network model
-for matching phonemes to the input autio features. It generates posterior
+for matching phonemes to the input audio features. It generates posterior
probabilities of phonemes from speech frontend features (log-mel filterbanks).
It has an input size of 320 (float), an output size of 42 (float), five LSTM
layers and one fully connected layers with a Softmax activation function, with
@@ -68,7 +68,7 @@ for predicting the probability of a word given previous words in a sentence.
It generates posterior probabilities of the next word based from a sequence of
words. The words are encoded as indices in a fixed size dictionary.
The model has two inputs both of size one (integer): the current word index and
-next word index, an output size of one (float): the log probability. It consits
+next word index, an output size of one (float): the log probability. It consists
of three embedding layer, three LSTM layers, followed by a multiplication, a
fully connected layers and an addition.
The corresponding parameters as shown in the figure.
diff --git a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
new file mode 100644
index 0000000000..5812de4b30
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
@@ -0,0 +1,202 @@
+load_model: "speech_asr_lm_model.tflite"
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 3
+ input: "63982"
+ input: "8409"
+ output: "-2.75389"
+}
+invoke {
+ id: 4
+ input: "8409"
+ input: "1488"
+ output: "0.601841"
+}
+invoke {
+ id: 5
+ input: "1488"
+ input: "63981"
+ output: "-0.314846"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 6
+ input: "63982"
+ input: "8409"
+ output: "-2.75389"
+}
+invoke {
+ id: 7
+ input: "8409"
+ input: "3082"
+ output: "-3.63721"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 8
+ input: "63982"
+ input: "8409"
+ output: "-2.75389"
+}
+invoke {
+ id: 9
+ input: "8409"
+ input: "18965"
+ output: "-6.93985"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 13
+ input: "63982"
+ input: "12516"
+ output: "-6.20867"
+}
+invoke {
+ id: 14
+ input: "12516"
+ input: "914"
+ output: "-0.407277"
+}
+invoke {
+ id: 15
+ input: "914"
+ input: "63981"
+ output: "-3.82091"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 19
+ input: "63982"
+ input: "12516"
+ output: "-6.20867"
+}
+invoke {
+ id: 20
+ input: "12516"
+ input: "914"
+ output: "-0.407277"
+}
+invoke {
+ id: 21
+ input: "914"
+ input: "48619"
+ output: "-4.02131"
+}
+invoke {
+ id: 22
+ input: "48619"
+ input: "63981"
+ output: "-0.677399"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 26
+ input: "63982"
+ input: "12516"
+ output: "-6.20867"
+}
+invoke {
+ id: 27
+ input: "12516"
+ input: "914"
+ output: "-0.407277"
+}
+invoke {
+ id: 28
+ input: "914"
+ input: "4700"
+ output: "-4.056"
+}
+invoke {
+ id: 29
+ input: "4700"
+ input: "63981"
+ output: "0.415889"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 30
+ input: "63982"
+ input: "12516"
+ output: "-6.20867"
+}
+invoke {
+ id: 31
+ input: "12516"
+ input: "914"
+ output: "-0.407277"
+invoke {
+ id: 32
+ input: "914"
+ input: "51923"
+ output: "-14.1147"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 34
+ input: "63982"
+ input: "5520"
+ output: "-4.56971"
+}
+invoke {
+ id: 35
+ input: "5520"
+ input: "16318"
+ output: "-1.54815"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 36
+ input: "63982"
+ input: "5520"
+ output: "-4.56971"
+}
+invoke {
+ id: 37
+ input: "5520"
+ input: "28303"
+ output: "-14.0947"
+}
+init_state: "21,22,42,43,63,64"
+invoke {
+ id: 38
+ input: "63982"
+ input: "12451"
+ output: "-6.24243"
+}
+invoke {
+ id: 39
+ input: "12451"
+ input: "752"
+ output: "0.0700736"
+}
+invoke {
+ id: 40
+ input: "752"
+ input: "11"
+ output: "-1.72744"
+}
+invoke {
+ id: 41
+ input: "11"
+ input: "19454"
+ output: "-3.19211"
+}
+invoke {
+ id: 42
+ input: "19454"
+ input: "16989"
+ output: "-4.01684"
+}
+invoke {
+ id: 43
+ input: "16989"
+ input: "40168"
+ output: "-8.91317"
+}
+invoke {
+ id: 44
+ input: "40168"
+ input: "63981"
+ output: "-0.675377"
+}
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index 3cda4bcccc..7019c29959 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -370,7 +370,7 @@ enum {
* Looks up items from a given tensor.
*
* Each item in the output is a raw copy of the corresponding item in
- * the input “values”. If the the given “lookup” indices are out of bounds,
+ * the input “values”. If the given “lookup” indices are out of bounds,
* the op will fail and an error will be reported.
*
* Inputs:
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index b3602f799e..d5b9319407 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -322,6 +322,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
case tflite::BuiltinOperator_EMBEDDING_LOOKUP:
case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE:
+ case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case tflite::BuiltinOperator_L2_NORMALIZATION:
case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
case tflite::BuiltinOperator_MUL:
@@ -338,6 +339,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_DIV:
case tflite::BuiltinOperator_SUB:
case tflite::BuiltinOperator_SQUEEZE:
+ case tflite::BuiltinOperator_STRIDED_SLICE:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index f29aa9e18e..e98000929a 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
+#define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/context.h"
@@ -63,4 +63,4 @@ class NNAPIDelegate {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
+#endif // TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h
index 54d4876095..1b6998cda3 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.h
+++ b/tensorflow/contrib/lite/optional_debug_tools.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Optional debugging functionality. For small sized binaries, these are not
// needed.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+#define TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
#include "tensorflow/contrib/lite/interpreter.h"
@@ -29,4 +29,4 @@ TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 4d87a5907b..3c369774be 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -31,10 +31,18 @@ import tempfile
from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
-from tensorflow.contrib.lite.toco.python.tensorflow_wrap_toco import TocoConvert as _toco_convert_protos
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+# Lazy load since some of the performance benchmark skylark rules
+# break dependencies.
+_toco_python = LazyLoader(
+ "tensorflow_wrap_toco", globals(),
+ "tensorflow.contrib.lite.toco.python."
+ "tensorflow_wrap_toco")
+del LazyLoader
# Enum types from the protobuf promoted to the API
FLOAT = _types_pb2.FLOAT
@@ -86,7 +94,8 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
# TODO(aselle): When toco does not use fatal errors for failure, we can
# switch this on.
if not _toco_from_proto_bin:
- return _toco_convert_protos(model_flags_str, toco_flags_str, input_data_str)
+ return _toco_python.TocoConvert(
+ model_flags_str, toco_flags_str, input_data_str)
with tempfile.NamedTemporaryFile() as fp_toco, \
tempfile.NamedTemporaryFile() as fp_model, \
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index f5251031b3..ec202cd407 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -53,12 +53,12 @@ table Tensor {
type:TensorType;
// An index that refers to the buffers table at the root of the model. Or,
// if there is no data buffer associated (i.e. intermediate results), then
- // this is 0 (which refers to an always existant empty buffer).
+ // this is 0 (which refers to an always existent empty buffer).
//
// The data_buffer itself is an opaque container, with the assumption that the
// target device is little-endian. In addition, all builtin operators assume
// the memory is ordered such that if `shape` is [4, 3, 2], then index
- // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k].
+ // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
buffer:uint;
name:string; // For debugging and importing back into tensorflow.
quantization:QuantizationParameters; // Optional.
@@ -117,6 +117,8 @@ enum BuiltinOperator : byte {
SUB = 41,
DIV = 42,
SQUEEZE = 43,
+ UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
+ STRIDED_SLICE = 45,
}
// Options for the builtin operators.
@@ -151,6 +153,8 @@ union BuiltinOptions {
SubOptions,
DivOptions,
SqueezeOptions,
+ SequenceRNNOptions,
+ StridedSliceOptions,
}
enum Padding : byte { SAME, VALID }
@@ -214,6 +218,12 @@ table RNNOptions {
fused_activation_function:ActivationFunctionType;
}
+// An implementation of TensorFlow dynamic_rnn with RNNCell.
+table SequenceRNNOptions {
+ time_major:bool;
+ fused_activation_function:ActivationFunctionType;
+}
+
// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
table FullyConnectedOptions {
fused_activation_function:ActivationFunctionType;
@@ -256,8 +266,6 @@ table LSTMOptions {
}
table ResizeBilinearOptions {
- new_height:int;
- new_width:int;
}
// A call operation options
@@ -267,8 +275,6 @@ table CallOptions {
}
table PadOptions {
- before_padding:[int];
- after_padding:[int];
}
table ReshapeOptions {
@@ -332,6 +338,14 @@ table SqueezeOptions {
squeeze_dims:[int];
}
+table StridedSliceOptions {
+ begin_mask: int;
+ end_mask: int;
+ ellipsis_mask: int;
+ new_axis_mask: int;
+ shrink_axis_mask: int;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index a2ec8e40e9..c04a73a2bf 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -48,6 +48,9 @@ struct SVDFOptionsT;
struct RNNOptions;
struct RNNOptionsT;
+struct SequenceRNNOptions;
+struct SequenceRNNOptionsT;
+
struct FullyConnectedOptions;
struct FullyConnectedOptionsT;
@@ -117,6 +120,9 @@ struct MeanOptionsT;
struct SqueezeOptions;
struct SqueezeOptionsT;
+struct StridedSliceOptions;
+struct StridedSliceOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -203,11 +209,13 @@ enum BuiltinOperator {
BuiltinOperator_SUB = 41,
BuiltinOperator_DIV = 42,
BuiltinOperator_SQUEEZE = 43,
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
+ BuiltinOperator_STRIDED_SLICE = 45,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_SQUEEZE
+ BuiltinOperator_MAX = BuiltinOperator_STRIDED_SLICE
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[41] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[43] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -249,7 +257,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[41] {
BuiltinOperator_MEAN,
BuiltinOperator_SUB,
BuiltinOperator_DIV,
- BuiltinOperator_SQUEEZE};
+ BuiltinOperator_SQUEEZE,
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOperator_STRIDED_SLICE};
return values;
}
@@ -298,6 +308,8 @@ inline const char **EnumNamesBuiltinOperator() {
"SUB",
"DIV",
"SQUEEZE",
+ "UNIDIRECTIONAL_SEQUENCE_LSTM",
+ "STRIDED_SLICE",
nullptr};
return names;
}
@@ -339,11 +351,13 @@ enum BuiltinOptions {
BuiltinOptions_SubOptions = 28,
BuiltinOptions_DivOptions = 29,
BuiltinOptions_SqueezeOptions = 30,
+ BuiltinOptions_SequenceRNNOptions = 31,
+ BuiltinOptions_StridedSliceOptions = 32,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_SqueezeOptions
+ BuiltinOptions_MAX = BuiltinOptions_StridedSliceOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[31] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[33] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -375,7 +389,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[31] {
BuiltinOptions_MeanOptions,
BuiltinOptions_SubOptions,
BuiltinOptions_DivOptions,
- BuiltinOptions_SqueezeOptions};
+ BuiltinOptions_SqueezeOptions,
+ BuiltinOptions_SequenceRNNOptions,
+ BuiltinOptions_StridedSliceOptions};
return values;
}
@@ -411,6 +427,8 @@ inline const char **EnumNamesBuiltinOptions() {
"SubOptions",
"DivOptions",
"SqueezeOptions",
+ "SequenceRNNOptions",
+ "StridedSliceOptions",
nullptr};
return names;
}
@@ -579,6 +597,16 @@ struct BuiltinOptionsTraits<SqueezeOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions;
};
+template <>
+struct BuiltinOptionsTraits<SequenceRNNOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions;
+};
+
+template <>
+struct BuiltinOptionsTraits<StridedSliceOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -926,6 +954,26 @@ struct BuiltinOptionsUnion {
? reinterpret_cast<const SqueezeOptionsT *>(value)
: nullptr;
}
+ SequenceRNNOptionsT *AsSequenceRNNOptions() {
+ return type == BuiltinOptions_SequenceRNNOptions
+ ? reinterpret_cast<SequenceRNNOptionsT *>(value)
+ : nullptr;
+ }
+ const SequenceRNNOptionsT *AsSequenceRNNOptions() const {
+ return type == BuiltinOptions_SequenceRNNOptions
+ ? reinterpret_cast<const SequenceRNNOptionsT *>(value)
+ : nullptr;
+ }
+ StridedSliceOptionsT *AsStridedSliceOptions() {
+ return type == BuiltinOptions_StridedSliceOptions
+ ? reinterpret_cast<StridedSliceOptionsT *>(value)
+ : nullptr;
+ }
+ const StridedSliceOptionsT *AsStridedSliceOptions() const {
+ return type == BuiltinOptions_StridedSliceOptions
+ ? reinterpret_cast<const StridedSliceOptionsT *>(value)
+ : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
@@ -1886,6 +1934,77 @@ flatbuffers::Offset<RNNOptions> CreateRNNOptions(
flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o,
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct SequenceRNNOptionsT : public flatbuffers::NativeTable {
+ typedef SequenceRNNOptions TableType;
+ bool time_major;
+ ActivationFunctionType fused_activation_function;
+ SequenceRNNOptionsT()
+ : time_major(false),
+ fused_activation_function(ActivationFunctionType_NONE) {}
+};
+
+struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SequenceRNNOptionsT NativeTableType;
+ enum { VT_TIME_MAJOR = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 };
+ bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; }
+ ActivationFunctionType fused_activation_function() const {
+ return static_cast<ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ verifier.EndTable();
+ }
+ SequenceRNNOptionsT *UnPack(
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(
+ SequenceRNNOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SequenceRNNOptions> Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SequenceRNNOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_time_major(bool time_major) {
+ fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_TIME_MAJOR,
+ static_cast<uint8_t>(time_major), 0);
+ }
+ void add_fused_activation_function(
+ ActivationFunctionType fused_activation_function) {
+ fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SequenceRNNOptionsBuilder &operator=(const SequenceRNNOptionsBuilder &);
+ flatbuffers::Offset<SequenceRNNOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SequenceRNNOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false,
+ ActivationFunctionType fused_activation_function =
+ ActivationFunctionType_NONE) {
+ SequenceRNNOptionsBuilder builder_(_fbb);
+ builder_.add_fused_activation_function(fused_activation_function);
+ builder_.add_time_major(time_major);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct FullyConnectedOptionsT : public flatbuffers::NativeTable {
typedef FullyConnectedOptions TableType;
ActivationFunctionType fused_activation_function;
@@ -2538,26 +2657,13 @@ flatbuffers::Offset<CallOptions> CreateCallOptions(
struct PadOptionsT : public flatbuffers::NativeTable {
typedef PadOptions TableType;
- std::vector<int32_t> before_padding;
- std::vector<int32_t> after_padding;
PadOptionsT() {}
};
struct PadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef PadOptionsT NativeTableType;
- enum { VT_BEFORE_PADDING = 4, VT_AFTER_PADDING = 6 };
- const flatbuffers::Vector<int32_t> *before_padding() const {
- return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BEFORE_PADDING);
- }
- const flatbuffers::Vector<int32_t> *after_padding() const {
- return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_AFTER_PADDING);
- }
bool Verify(flatbuffers::Verifier &verifier) const {
- return VerifyTableStart(verifier) &&
- VerifyOffset(verifier, VT_BEFORE_PADDING) &&
- verifier.Verify(before_padding()) &&
- VerifyOffset(verifier, VT_AFTER_PADDING) &&
- verifier.Verify(after_padding()) && verifier.EndTable();
+ return VerifyTableStart(verifier) && verifier.EndTable();
}
PadOptionsT *UnPack(
const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2572,14 +2678,6 @@ struct PadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
struct PadOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
- void add_before_padding(
- flatbuffers::Offset<flatbuffers::Vector<int32_t>> before_padding) {
- fbb_.AddOffset(PadOptions::VT_BEFORE_PADDING, before_padding);
- }
- void add_after_padding(
- flatbuffers::Offset<flatbuffers::Vector<int32_t>> after_padding) {
- fbb_.AddOffset(PadOptions::VT_AFTER_PADDING, after_padding);
- }
explicit PadOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2593,24 +2691,11 @@ struct PadOptionsBuilder {
};
inline flatbuffers::Offset<PadOptions> CreatePadOptions(
- flatbuffers::FlatBufferBuilder &_fbb,
- flatbuffers::Offset<flatbuffers::Vector<int32_t>> before_padding = 0,
- flatbuffers::Offset<flatbuffers::Vector<int32_t>> after_padding = 0) {
+ flatbuffers::FlatBufferBuilder &_fbb) {
PadOptionsBuilder builder_(_fbb);
- builder_.add_after_padding(after_padding);
- builder_.add_before_padding(before_padding);
return builder_.Finish();
}
-inline flatbuffers::Offset<PadOptions> CreatePadOptionsDirect(
- flatbuffers::FlatBufferBuilder &_fbb,
- const std::vector<int32_t> *before_padding = nullptr,
- const std::vector<int32_t> *after_padding = nullptr) {
- return tflite::CreatePadOptions(
- _fbb, before_padding ? _fbb.CreateVector<int32_t>(*before_padding) : 0,
- after_padding ? _fbb.CreateVector<int32_t>(*after_padding) : 0);
-}
-
flatbuffers::Offset<PadOptions> CreatePadOptions(
flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o,
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -3437,6 +3522,111 @@ flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions(
flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o,
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct StridedSliceOptionsT : public flatbuffers::NativeTable {
+ typedef StridedSliceOptions TableType;
+ int32_t begin_mask;
+ int32_t end_mask;
+ int32_t ellipsis_mask;
+ int32_t new_axis_mask;
+ int32_t shrink_axis_mask;
+ StridedSliceOptionsT()
+ : begin_mask(0),
+ end_mask(0),
+ ellipsis_mask(0),
+ new_axis_mask(0),
+ shrink_axis_mask(0) {}
+};
+
+struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS
+ : private flatbuffers::Table {
+ typedef StridedSliceOptionsT NativeTableType;
+ enum {
+ VT_BEGIN_MASK = 4,
+ VT_END_MASK = 6,
+ VT_ELLIPSIS_MASK = 8,
+ VT_NEW_AXIS_MASK = 10,
+ VT_SHRINK_AXIS_MASK = 12
+ };
+ int32_t begin_mask() const { return GetField<int32_t>(VT_BEGIN_MASK, 0); }
+ int32_t end_mask() const { return GetField<int32_t>(VT_END_MASK, 0); }
+ int32_t ellipsis_mask() const {
+ return GetField<int32_t>(VT_ELLIPSIS_MASK, 0);
+ }
+ int32_t new_axis_mask() const {
+ return GetField<int32_t>(VT_NEW_AXIS_MASK, 0);
+ }
+ int32_t shrink_axis_mask() const {
+ return GetField<int32_t>(VT_SHRINK_AXIS_MASK, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_BEGIN_MASK) &&
+ VerifyField<int32_t>(verifier, VT_END_MASK) &&
+ VerifyField<int32_t>(verifier, VT_ELLIPSIS_MASK) &&
+ VerifyField<int32_t>(verifier, VT_NEW_AXIS_MASK) &&
+ VerifyField<int32_t>(verifier, VT_SHRINK_AXIS_MASK) &&
+ verifier.EndTable();
+ }
+ StridedSliceOptionsT *UnPack(
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(
+ StridedSliceOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<StridedSliceOptions> Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct StridedSliceOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_begin_mask(int32_t begin_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_BEGIN_MASK, begin_mask, 0);
+ }
+ void add_end_mask(int32_t end_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_END_MASK, end_mask, 0);
+ }
+ void add_ellipsis_mask(int32_t ellipsis_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_ELLIPSIS_MASK,
+ ellipsis_mask, 0);
+ }
+ void add_new_axis_mask(int32_t new_axis_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_NEW_AXIS_MASK,
+ new_axis_mask, 0);
+ }
+ void add_shrink_axis_mask(int32_t shrink_axis_mask) {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_SHRINK_AXIS_MASK,
+ shrink_axis_mask, 0);
+ }
+ explicit StridedSliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ StridedSliceOptionsBuilder &operator=(const StridedSliceOptionsBuilder &);
+ flatbuffers::Offset<StridedSliceOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<StridedSliceOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, int32_t begin_mask = 0,
+ int32_t end_mask = 0, int32_t ellipsis_mask = 0, int32_t new_axis_mask = 0,
+ int32_t shrink_axis_mask = 0) {
+ StridedSliceOptionsBuilder builder_(_fbb);
+ builder_.add_shrink_axis_mask(shrink_axis_mask);
+ builder_.add_new_axis_mask(new_axis_mask);
+ builder_.add_ellipsis_mask(ellipsis_mask);
+ builder_.add_end_mask(end_mask);
+ builder_.add_begin_mask(begin_mask);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -3716,6 +3906,16 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
? static_cast<const SqueezeOptions *>(builtin_options())
: nullptr;
}
+ const SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const {
+ return builtin_options_type() == BuiltinOptions_SequenceRNNOptions
+ ? static_cast<const SequenceRNNOptions *>(builtin_options())
+ : nullptr;
+ }
+ const StridedSliceOptions *builtin_options_as_StridedSliceOptions() const {
+ return builtin_options_type() == BuiltinOptions_StridedSliceOptions
+ ? static_cast<const StridedSliceOptions *>(builtin_options())
+ : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -3917,6 +4117,18 @@ inline const SqueezeOptions *Operator::builtin_options_as<SqueezeOptions>()
return builtin_options_as_SqueezeOptions();
}
+template <>
+inline const SequenceRNNOptions *
+Operator::builtin_options_as<SequenceRNNOptions>() const {
+ return builtin_options_as_SequenceRNNOptions();
+}
+
+template <>
+inline const StridedSliceOptions *
+Operator::builtin_options_as<StridedSliceOptions>() const {
+ return builtin_options_as_StridedSliceOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -4841,6 +5053,51 @@ inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(
return tflite::CreateRNNOptions(_fbb, _fused_activation_function);
}
+inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(
+ const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SequenceRNNOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SequenceRNNOptions::UnPackTo(
+ SequenceRNNOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ {
+ auto _e = time_major();
+ _o->time_major = _e;
+ };
+ {
+ auto _e = fused_activation_function();
+ _o->fused_activation_function = _e;
+ };
+}
+
+inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSequenceRNNOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs {
+ flatbuffers::FlatBufferBuilder *__fbb;
+ const SequenceRNNOptionsT *__o;
+ const flatbuffers::rehasher_function_t *__rehasher;
+ } _va = {&_fbb, _o, _rehasher};
+ (void)_va;
+ auto _time_major = _o->time_major;
+ auto _fused_activation_function = _o->fused_activation_function;
+ return tflite::CreateSequenceRNNOptions(_fbb, _time_major,
+ _fused_activation_function);
+}
+
inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(
const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new FullyConnectedOptionsT();
@@ -5281,24 +5538,6 @@ inline void PadOptions::UnPackTo(
PadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
- {
- auto _e = before_padding();
- if (_e) {
- _o->before_padding.resize(_e->size());
- for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
- _o->before_padding[_i] = _e->Get(_i);
- }
- }
- };
- {
- auto _e = after_padding();
- if (_e) {
- _o->after_padding.resize(_e->size());
- for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
- _o->after_padding[_i] = _e->Get(_i);
- }
- }
- };
}
inline flatbuffers::Offset<PadOptions> PadOptions::Pack(
@@ -5318,11 +5557,7 @@ inline flatbuffers::Offset<PadOptions> CreatePadOptions(
const flatbuffers::rehasher_function_t *__rehasher;
} _va = {&_fbb, _o, _rehasher};
(void)_va;
- auto _before_padding =
- _o->before_padding.size() ? _fbb.CreateVector(_o->before_padding) : 0;
- auto _after_padding =
- _o->after_padding.size() ? _fbb.CreateVector(_o->after_padding) : 0;
- return tflite::CreatePadOptions(_fbb, _before_padding, _after_padding);
+ return tflite::CreatePadOptions(_fbb);
}
inline ReshapeOptionsT *ReshapeOptions::UnPack(
@@ -5889,6 +6124,67 @@ inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions(
return tflite::CreateSqueezeOptions(_fbb, _squeeze_dims);
}
+inline StridedSliceOptionsT *StridedSliceOptions::UnPack(
+ const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new StridedSliceOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void StridedSliceOptions::UnPackTo(
+ StridedSliceOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ {
+ auto _e = begin_mask();
+ _o->begin_mask = _e;
+ };
+ {
+ auto _e = end_mask();
+ _o->end_mask = _e;
+ };
+ {
+ auto _e = ellipsis_mask();
+ _o->ellipsis_mask = _e;
+ };
+ {
+ auto _e = new_axis_mask();
+ _o->new_axis_mask = _e;
+ };
+ {
+ auto _e = shrink_axis_mask();
+ _o->shrink_axis_mask = _e;
+ };
+}
+
+inline flatbuffers::Offset<StridedSliceOptions> StridedSliceOptions::Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateStridedSliceOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<StridedSliceOptions> CreateStridedSliceOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs {
+ flatbuffers::FlatBufferBuilder *__fbb;
+ const StridedSliceOptionsT *__o;
+ const flatbuffers::rehasher_function_t *__rehasher;
+ } _va = {&_fbb, _o, _rehasher};
+ (void)_va;
+ auto _begin_mask = _o->begin_mask;
+ auto _end_mask = _o->end_mask;
+ auto _ellipsis_mask = _o->ellipsis_mask;
+ auto _new_axis_mask = _o->new_axis_mask;
+ auto _shrink_axis_mask = _o->shrink_axis_mask;
+ return tflite::CreateStridedSliceOptions(_fbb, _begin_mask, _end_mask,
+ _ellipsis_mask, _new_axis_mask,
+ _shrink_axis_mask);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(
const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
@@ -6397,6 +6693,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier,
auto ptr = reinterpret_cast<const SqueezeOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_StridedSliceOptions: {
+ auto ptr = reinterpret_cast<const StridedSliceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default:
return false;
}
@@ -6541,6 +6845,14 @@ inline void *BuiltinOptionsUnion::UnPack(
auto ptr = reinterpret_cast<const SqueezeOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_StridedSliceOptions: {
+ auto ptr = reinterpret_cast<const StridedSliceOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default:
return nullptr;
}
@@ -6672,6 +6984,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(
auto ptr = reinterpret_cast<const SqueezeOptionsT *>(value);
return CreateSqueezeOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const SequenceRNNOptionsT *>(value);
+ return CreateSequenceRNNOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_StridedSliceOptions: {
+ auto ptr = reinterpret_cast<const StridedSliceOptionsT *>(value);
+ return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union();
+ }
default:
return 0;
}
@@ -6817,6 +7137,16 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u)
new SqueezeOptionsT(*reinterpret_cast<SqueezeOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ value = new SequenceRNNOptionsT(
+ *reinterpret_cast<SequenceRNNOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_StridedSliceOptions: {
+ value = new StridedSliceOptionsT(
+ *reinterpret_cast<StridedSliceOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -6974,6 +7304,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ auto ptr = reinterpret_cast<SequenceRNNOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_StridedSliceOptions: {
+ auto ptr = reinterpret_cast<StridedSliceOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default:
break;
}
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
index 07a38c4243..0c5e00a1f2 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.h
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+#define TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
#include <list>
#include <memory>
@@ -85,4 +85,4 @@ class SimpleMemoryArena {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
+#endif // TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
index 12872d1123..c35a2fff3c 100644
--- a/tensorflow/contrib/lite/string_util.h
+++ b/tensorflow/contrib/lite/string_util.h
@@ -37,8 +37,8 @@ limitations under the License.
// # described above.
// buf.WriteToTensor(tensor)
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
#include <vector>
@@ -88,4 +88,4 @@ int GetStringCount(const TfLiteTensor* tensor);
StringRef GetString(const TfLiteTensor* tensor, int string_index);
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 933da11353..50e8ca75f8 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -46,6 +46,7 @@ gen_zipped_test_files(
"space_to_batch_nd.zip",
"space_to_depth.zip",
"squeeze.zip",
+ "strided_slice.zip",
"sub.zip",
"transpose.zip",
],
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 6c3d31fc9a..a639351657 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -853,34 +853,55 @@ def make_fused_batch_norm_tests(zip_path):
def make_conv_tests(zip_path):
"""Make a set of tests to do convolution."""
- test_parameters = [{
- "input_shape": [[1, 3, 4, 3]],
- "filter_shape": [[1, 1, 3, 2]],
- "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
- "padding": ["SAME", "VALID"],
- "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
- }, {
- "input_shape": [[2, 14, 14, 2]],
- "filter_shape": [[6, 6, 2, 2]],
- "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
- "padding": ["SAME", "VALID"],
- "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
- }]
+ test_parameters = [
+ {
+ "input_shape": [[1, 3, 4, 3]],
+ "filter_shape": [[1, 1, 3, 2]],
+ "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ "constant_filter": [True, False],
+ },
+ {
+ "input_shape": [[2, 14, 14, 2]],
+ "filter_shape": [[6, 6, 2, 2]],
+ "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"], # TODO(aselle): NCHW would be good
+ "constant_filter": [True, False],
+ }
+ ]
def build_graph(parameters):
+ """Build a conv graph given `parameters`."""
input_tensor = tf.placeholder(
dtype=tf.float32, name="input", shape=parameters["input_shape"])
- filter_values = create_tensor_data(np.float32, parameters["filter_shape"])
- out = tf.nn.conv2d(input_tensor, filter_values,
- strides=parameters["strides"],
- padding=parameters["padding"],
- data_format=parameters["data_format"])
- return [input_tensor], [out]
+
+ # Get filter input either as a placeholder or constants. Also get a list of
+ # the input tensors that are represented as placeholders.
+ if parameters["constant_filter"]:
+ filter_input = create_tensor_data(np.float32, parameters["filter_shape"])
+ input_tensors = [input_tensor]
+ else:
+ filter_input = tf.placeholder(
+ dtype=tf.float32, name="filter", shape=parameters["filter_shape"])
+ input_tensors = [input_tensor, filter_input]
+
+ out = tf.nn.conv2d(
+ input_tensor,
+ filter_input,
+ strides=parameters["strides"],
+ padding=parameters["padding"],
+ data_format=parameters["data_format"])
+ return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
- input_values = create_tensor_data(np.float32, parameters["input_shape"])
- return [input_values], sess.run(
- outputs, feed_dict=dict(zip(inputs, [input_values])))
+ # Build list of input values either containing 1 tensor (input) or 2 tensors
+ # (input, filter) based on whether filter is constant or variable input.
+ values = [create_tensor_data(np.float32, parameters["input_shape"])]
+ if not parameters["constant_filter"]:
+ values.append(create_tensor_data(np.float32, parameters["filter_shape"]))
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@@ -889,45 +910,70 @@ def make_depthwiseconv_tests(zip_path):
"""Make a set of tests to do convolution."""
# Tensorflow only supports equal strides
- test_parameters = [{
- "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
- "filter_size": [[1, 1], [1, 2], [3, 3]],
- "strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
- "channel_multiplier": [1, 2],
- "rate": [[1, 1]],
- "padding": ["SAME", "VALID"],
- "data_format": ["NHWC"],
- }, {
- "input_shape": [[1, 3, 4, 3]],
- "filter_size": [[1, 1]],
- "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
- "channel_multiplier": [2],
- "rate": [[2, 2]], # Only [1, 1] is supported
- "padding": ["SAME"],
- "data_format": ["NHWC"],
- }]
+ test_parameters = [
+ {
+ "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
+ "filter_size": [[1, 1], [1, 2], [3, 3]],
+ "strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+ "channel_multiplier": [1, 2],
+ "rate": [[1, 1]],
+ "padding": ["SAME", "VALID"],
+ "data_format": ["NHWC"],
+ "constant_filter": [True, False],
+ },
+ {
+ "input_shape": [[1, 3, 4, 3]],
+ "filter_size": [[1, 1]],
+ "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
+ "channel_multiplier": [2],
+ "rate": [[2, 2]], # Only [1, 1] is supported
+ "padding": ["SAME"],
+ "data_format": ["NHWC"],
+ "constant_filter": [True, False],
+ }
+ ]
- def build_graph(parameters):
- """Build a depthwise conv graph given `parameters`."""
+ def get_tensor_shapes(parameters):
input_shape = parameters["input_shape"]
filter_size = parameters["filter_size"]
+ filter_shape = filter_size + [
+ input_shape[3], parameters["channel_multiplier"]
+ ]
+ return [input_shape, filter_shape]
+
+ def build_graph(parameters):
+ """Build a depthwise conv graph given `parameters`."""
+ input_shape, filter_shape = get_tensor_shapes(parameters)
input_tensor = tf.placeholder(
dtype=tf.float32, name="input", shape=input_shape)
- filter_shape = filter_size + [
- input_shape[3], parameters["channel_multiplier"]]
- filter_values = create_tensor_data(np.float32, filter_shape)
+
+ # Get filter input either as a placeholder or constants. Also get a list of
+ # the input tensors that are represented as placeholders.
+ if parameters["constant_filter"]:
+ filter_input = create_tensor_data(np.float32, filter_shape)
+ input_tensors = [input_tensor]
+ else:
+ filter_input = tf.placeholder(
+ dtype=tf.float32, name="filter", shape=filter_shape)
+ input_tensors = [input_tensor, filter_input]
+
out = tf.nn.depthwise_conv2d(
- input_tensor, filter_values,
+ input_tensor,
+ filter_input,
strides=parameters["strides"],
rate=parameters["rate"],
padding=parameters["padding"],
data_format=parameters["data_format"])
- return [input_tensor], [out]
+ return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
- input_values = create_tensor_data(np.float32, parameters["input_shape"])
- return [input_values], sess.run(
- outputs, feed_dict=dict(zip(inputs, [input_values])))
+ # Build list of input values either containing 1 tensor (input) or 2 tensors
+ # (input, filter) based on whether filter is constant or variable input.
+ input_shape, filter_shape = get_tensor_shapes(parameters)
+ values = [create_tensor_data(np.float32, input_shape)]
+ if not parameters["constant_filter"]:
+ values.append(create_tensor_data(np.float32, filter_shape))
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@@ -978,32 +1024,49 @@ def make_fully_connected_tests(zip_path):
"shape2": [[3, 3]],
"transpose_a": [True, False],
"transpose_b": [True, False],
+ "constant_filter": [True, False],
}, {
"shape1": [[4, 4], [1, 4], [4]],
"shape2": [[4, 4], [4, 1], [4]],
"transpose_a": [False],
"transpose_b": [False],
+ "constant_filter": [True, False],
}, {
"shape1": [[40, 37]],
"shape2": [[37, 40]],
"transpose_a": [False],
"transpose_b": [False],
-
+ "constant_filter": [True, False],
}]
def build_graph(parameters):
+ """Build a matmul graph given `parameters`."""
input_tensor1 = tf.placeholder(dtype=tf.float32, name="input1",
shape=parameters["shape1"])
- input_tensor2 = create_tensor_data(np.float32, parameters["shape2"])
+
+ # Get input_tensor2 either as a placeholder or constants. Also get a list of
+ # the input tensors that are represented as placeholders.
+ if parameters["constant_filter"]:
+ input_tensor2 = create_tensor_data(np.float32, parameters["shape2"])
+ input_tensors = [input_tensor1]
+ else:
+ input_tensor2 = tf.placeholder(
+ dtype=tf.float32, name="input2", shape=parameters["shape2"])
+ input_tensors = [input_tensor1, input_tensor2]
+
out = tf.matmul(input_tensor1, input_tensor2,
transpose_a=parameters["transpose_a"],
transpose_b=parameters["transpose_b"])
- return [input_tensor1], [out]
+ return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
- input_values1 = create_tensor_data(np.float32, shape=parameters["shape1"])
- return [input_values1], sess.run(
- outputs, feed_dict=dict(zip(inputs, [input_values1])))
+ # Build list of input values either containing 1 tensor (input_values1) or 2
+ # tensors (input_values1, input_values2) based on whether the second input
+ # is a constant or variable input.
+ values = [create_tensor_data(np.float32, shape=parameters["shape1"])]
+ if not parameters["constant_filter"]:
+ values.append(create_tensor_data(np.float32, parameters["shape2"]))
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@@ -1078,28 +1141,43 @@ def make_pad_tests(zip_path):
"input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
"paddings": [[[0, 0], [0, 1], [2, 3], [0, 0]], [[0, 1], [0, 0],
[0, 0], [2, 3]]],
+ "constant_paddings": [True, False],
},
# Non-4D use case.
{
"dtype": [tf.int32, tf.int64, tf.float32],
"input_shape": [[1, 2], [0, 1, 2]],
"paddings": [[[0, 1], [2, 3]]],
+ "constant_paddings": [True, False],
},
]
def build_graph(parameters):
+ """Build a pad graph given `parameters`."""
input_tensor = tf.placeholder(
dtype=parameters["dtype"],
name="input",
shape=parameters["input_shape"])
- out = tf.pad(input_tensor, paddings=parameters["paddings"])
- return [input_tensor], [out]
+
+ # Get paddings as either a placeholder or constants.
+ if parameters["constant_paddings"]:
+ paddings = parameters["paddings"]
+ input_tensors = [input_tensor]
+ else:
+ shape = [len(parameters["paddings"]), 2]
+ paddings = tf.placeholder(dtype=tf.int32, name="padding", shape=shape)
+ input_tensors = [input_tensor, paddings]
+
+ out = tf.pad(input_tensor, paddings=paddings)
+ return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
- input_values = create_tensor_data(parameters["dtype"],
- parameters["input_shape"])
- return [input_values], sess.run(
- outputs, feed_dict=dict(zip(inputs, [input_values])))
+ values = [
+ create_tensor_data(parameters["dtype"], parameters["input_shape"])
+ ]
+ if not parameters["constant_paddings"]:
+ values.append(np.array(parameters["paddings"]))
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@@ -1361,6 +1439,10 @@ def make_squeeze_tests(zip_path):
"dtype": [tf.int32, tf.float32, tf.int64],
"input_shape": [[1]],
"axis": [None, [], [0], [-1]],
+ }, {
+ "dtype": [tf.int32, tf.float32, tf.int64],
+ "input_shape": [[1, 1, 1, 1, 1]],
+ "axis": [None, [], [0], [3, 0], [-2, 0, 3, 2]],
}]
def build_graph(parameters):
@@ -1380,6 +1462,97 @@ def make_squeeze_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_strided_slice_tests(zip_path):
+ """Make a set of tests to do strided_slice."""
+
+ # TODO(soroosh): add test/support for uint8.
+ test_parameters = [
+ # 4-D
+ {
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ "index_type": [tf.int32],
+ "input_shape": [[12, 2, 2, 5]],
+ "begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
+ "end": [[8, 2, 2, 3], [12, 2, 2, 5]],
+ "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]],
+ "begin_mask": [None, 0, 1, 2, 8],
+ "end_mask": [None, 0, 1, 2, 8],
+ },
+ # 2-D
+ {
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ "index_type": [tf.int32],
+ "input_shape": [[2, 3]],
+ "begin": [[0, 0], [1, 0]],
+ "end": [[2, 3], [2, 2]],
+ "strides": [None, [1, 1], [2, 2]],
+ "begin_mask": [None, 0, 1, 2],
+ "end_mask": [None, 0, 1, 2],
+ },
+ # Negative strides
+ {
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ "index_type": [tf.int32],
+ "input_shape": [[2, 3]],
+ "begin": [[0, -1]],
+ "end": [[2, -3]],
+ "strides": [[1, -1]],
+ "begin_mask": [None, 0, 1, 2],
+ "end_mask": [None, 0, 1, 2],
+ },
+ ]
+
+ def build_graph(parameters):
+ """Build graph for stride_slice test."""
+ input_tensor = tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+ begin = tf.placeholder(
+ dtype=parameters["index_type"],
+ name="begin",
+ shape=[len(parameters["input_shape"])])
+ end = tf.placeholder(
+ dtype=parameters["index_type"],
+ name="end",
+ shape=[len(parameters["input_shape"])])
+ strides = (
+ tf.placeholder(
+ dtype=parameters["index_type"],
+ name="strides",
+ shape=[len(parameters["input_shape"])])
+ if parameters["strides"] is not None else None)
+ tensors = [input_tensor, begin, end]
+ if strides is not None:
+ tensors.append(strides)
+ out = tf.strided_slice(
+ input_tensor,
+ begin,
+ end,
+ strides,
+ begin_mask=parameters["begin_mask"],
+ end_mask=parameters["end_mask"])
+ return tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build inputs for stride_slice test."""
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ index_type = _TF_TYPE_INFO[parameters["index_type"]][0]
+ begin_values = np.array(parameters["begin"]).astype(index_type)
+ end_values = np.array(parameters["end"]).astype(index_type)
+ stride_values = (
+ np.array(parameters["strides"]).astype(index_type)
+ if parameters["strides"] is not None else None)
+ values = [input_values, begin_values, end_values]
+ if stride_values is not None:
+ values.append(stride_values)
+
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
"""Given an input perform a sequence of TensorFlow ops to produce l2pool."""
return tf.sqrt(tf.nn.avg_pool(
@@ -1438,6 +1611,7 @@ def main(unused_args):
"transpose.zip": make_transpose_tests,
"mean.zip": make_mean_tests,
"squeeze.zip": make_squeeze_tests,
+ "strided_slice.zip": make_strided_slice_tests,
}
out = FLAGS.zip_to_output
bin_path = FLAGS.toco
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index c8a6e07abd..41652a07d2 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -48,47 +48,51 @@ tensorflow::Env* env = tensorflow::Env::Default();
// TODO(ahentz): make sure we clean this list up frequently.
std::map<string, string> kBrokenTests = {
// Add doesn't support broadcasting.
- {R"(adda.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
- {R"(mula.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
- {R"(diva.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
- {R"(suba.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+ {R"(^\/adda.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+ {R"(^\/mula.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+ {R"(^\/diva.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
+ {R"(^\/suba.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"},
// Add only supports float32. (and "constant" tests use Add)
- {R"(adda.*int32)", "68808744"},
- {R"(constant.*int32)", "68808744"},
- {R"(mul.*int32)", "68808744"},
- {R"(div.*int32)", "68808744"},
- {R"(sub.*int32)", "68808744"},
+ {R"(^\/adda.*int32)", "68808744"},
+ {R"(^\/constant.*int32)", "68808744"},
+ {R"(^\/mul.*int32)", "68808744"},
+ {R"(^\/div.*int32)", "68808744"},
+ {R"(^\/sub.*int32)", "68808744"},
// Pad only supports 4D tensors.
- {R"(paddtype=.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
+ {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
"70527055"},
// L2Norm only supports tensors with 4D or fewer.
- {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
+ {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
// SpaceToBatch only supports 4D tensors.
- {R"(space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"},
+ {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"},
// L2Norm only works for dim=-1.
- {R"(l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"},
- {R"(l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"},
- {R"(l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
- {R"(l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
- {R"(l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
- {R"(l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
- {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"},
+ {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"},
+ {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])",
+ "67963812"},
+ {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
// ResizeBilinear looks completely incompatible with Tensorflow
- {R"(resize_bilinear)", "67964336"},
+ {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},
+ {R"(^\/resize_bilinearalign_corners=True,.*,size=\[2,2\])", "72401483"},
+ {R"(^\/resize_bilinearalign_corners=True,.*,size=\[4,3\])", "72401483"},
+ {R"(^\/resize_bilinearalign_corners=True,.*,size=\[5,6\])", "72401483"},
// Transpose only supports 1D-4D input tensors.
- {R"(transposedtype=.*,input_shape=\[.,.,.,.,.\],perm=.*)", "71545879"},
+ {R"(^\/transposedtype=.*,input_shape=\[.,.,.,.,.\],perm=.*)", "71545879"},
};
// Allows test data to be unzipped into a temporary directory and makes
@@ -263,6 +267,7 @@ INSTANTIATE_TESTS(div)
INSTANTIATE_TESTS(transpose)
INSTANTIATE_TESTS(mean)
INSTANTIATE_TESTS(squeeze)
+INSTANTIATE_TESTS(strided_slice)
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/contrib/lite/testing/message.h
index 78ef7e2cbe..e2bc408214 100644
--- a/tensorflow/contrib/lite/testing/message.h
+++ b/tensorflow/contrib/lite/testing/message.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
#include <memory>
#include <string>
@@ -79,4 +79,4 @@ class Message {
} // namespace testing
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc
index 7c371f2bd4..0caef0fe22 100644
--- a/tensorflow/contrib/lite/testing/parse_testdata.cc
+++ b/tensorflow/contrib/lite/testing/parse_testdata.cc
@@ -18,6 +18,7 @@ limitations under the License.
// ASCII file.
#include "tensorflow/contrib/lite/testing/parse_testdata.h"
+#include <cinttypes>
#include <cmath>
#include <cstdint>
#include <cstdio>
@@ -218,8 +219,8 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
int32_t computed = data[idx];
int32_t reference = example.outputs[0].flat_data[idx];
if (std::abs(computed - reference) > 0) {
- fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %f\n",
- i, idx, data[idx], example.outputs[0].flat_data[idx]);
+ fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n",
+ i, idx, computed, reference);
return kTfLiteError;
}
}
@@ -231,8 +232,9 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
int64_t reference = example.outputs[0].flat_data[idx];
if (std::abs(computed - reference) > 0) {
fprintf(stderr,
- "output[%zu][%zu] did not match %ld vs reference %f\n", i,
- idx, data[idx], example.outputs[0].flat_data[idx]);
+ "output[%zu][%zu] did not match %" PRId64
+ " vs reference %" PRId64 "\n",
+ i, idx, computed, reference);
return kTfLiteError;
}
}
diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h
index 90839fe245..7ebf362eb9 100644
--- a/tensorflow/contrib/lite/testing/parse_testdata.h
+++ b/tensorflow/contrib/lite/testing/parse_testdata.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+#define TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
#include <vector>
#include "tensorflow/contrib/lite/interpreter.h"
@@ -71,4 +71,4 @@ bool ParseAndRunTests(std::istream* input, TestRunner* test_runner);
} // namespace testing
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
+#endif // TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_
diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h
index cfc1e929e9..428cfda4f2 100644
--- a/tensorflow/contrib/lite/testing/split.h
+++ b/tensorflow/contrib/lite/testing/split.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
#include <cstdlib>
#include <string>
@@ -83,4 +83,4 @@ inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
} // namespace testing
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_
diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h
index f4b26949b5..60eaafa474 100644
--- a/tensorflow/contrib/lite/testing/test_runner.h
+++ b/tensorflow/contrib/lite/testing/test_runner.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
#include <memory>
#include <string>
@@ -121,4 +121,4 @@ class TestRunner {
} // namespace testing
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index 4440d4285e..25689a9fb4 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
#include <map>
@@ -59,4 +59,4 @@ class TfLiteDriver : public TestRunner {
} // namespace testing
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_
diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h
index daccf0e84a..7ed8eb96b7 100644
--- a/tensorflow/contrib/lite/testing/tokenize.h
+++ b/tensorflow/contrib/lite/testing/tokenize.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
#include <istream>
#include <string>
@@ -39,4 +39,4 @@ void Tokenize(std::istream* input, TokenProcessor* processor);
} // namespace testing
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h
index 4d4304f022..6d20aec141 100644
--- a/tensorflow/contrib/lite/testing/util.h
+++ b/tensorflow/contrib/lite/testing/util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
namespace tflite {
@@ -25,4 +25,4 @@ inline void LogToStderr() {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 967e304742..6fc7e5e3fd 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -160,6 +160,7 @@ cc_library(
],
deps = [
# Placeholder for internal file dependency.
+ "@protobuf_archive//:protobuf_headers",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -171,6 +172,8 @@ cc_library(
srcs = [
"graph_transformations/convert_expanddims_to_reshape.cc",
"graph_transformations/convert_pure_conv_to_depthwise.cc",
+ "graph_transformations/convert_reorder_axes.cc",
+ "graph_transformations/convert_trivial_addn_to_add.cc",
"graph_transformations/convert_trivial_transpose_to_reshape.cc",
"graph_transformations/create_im2col_arrays.cc",
"graph_transformations/dequantize.cc",
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index d4da8f5dfe..49cc1fc2aa 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -148,7 +148,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name,
if (!IsAllocatableTransientArray(model, array_name)) {
return 0;
}
- const auto& array = model.arrays.at(array_name);
+ const auto& array = &model.GetArray(array_name);
CHECK(array->has_shape())
<< "Array '" << array_name << "' doesn't have a shape";
if (array->data_type == ArrayDataType::kNone) {
@@ -158,9 +158,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name,
LOG(FATAL)
<< "A RNN state array, " << array_name << ", still does not "
<< "have a known data type after all graph transformations have "
- << "run. That's mostly a toco bug --- sorry. For now, you can "
- << "work around this issue by adding manually_create:true in the "
- << "--rnn_state description of this RNN state.";
+ << "run.";
}
}
LOG(FATAL) << "An array, " << array_name << ", still does not "
@@ -185,7 +183,7 @@ void AllocateTransientArray(const Model& model, const string& array_name,
}
const std::size_t size =
TransientArraySize(model, array_name, transient_data_alignment);
- const auto& array = model.arrays.at(array_name);
+ const auto& array = &model.GetArray(array_name);
CHECK(!array->alloc);
allocator->Allocate(size, &array->GetOrCreateAlloc());
}
@@ -197,7 +195,7 @@ void DeallocateTransientArray(const Model& model, const string& array_name,
if (!IsAllocatableTransientArray(model, array_name)) {
return;
}
- const auto& array = model.arrays.at(array_name);
+ const auto& array = &model.GetArray(array_name);
CHECK(!!array->alloc);
allocator->Deallocate(*array->alloc);
}
@@ -231,7 +229,7 @@ void AllocateTransientArrays(Model* model,
// Construct a sorted map of array names, so that other layout engines can
// match exactly.
std::map<string, const Array*> ordered_arrays_map;
- for (const auto& pair : model->arrays) {
+ for (const auto& pair : model->GetArrayMap()) {
ordered_arrays_map[pair.first] = pair.second.get();
}
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h
index 12d0d0498f..59d8ada1e9 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
#include "tensorflow/contrib/lite/toco/model.h"
@@ -41,4 +41,4 @@ void AllocateTransientArrays(Model* model,
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index eb2d7ba916..b97a4720a7 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -15,8 +15,8 @@ limitations under the License.
// This abstracts command line arguments in toco.
// Arg<T> is a parseable type that can register a default value, be able to
// parse itself, and keep track of whether it was specified.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
#include <functional>
#include <unordered_map>
@@ -147,12 +147,12 @@ class Arg<toco::StringMapList> final {
if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false;
if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false;
const std::vector<string> inner_fields_vector =
- strings::Split(outer_member, ',');
+ absl::StrSplit(outer_member, ',');
std::unordered_map<string, string> element;
for (const string& member_field : inner_fields_vector) {
std::vector<string> outer_member_key_value =
- strings::Split(member_field, ':');
+ absl::StrSplit(member_field, ':');
if (outer_member_key_value.size() != 2) return false;
string& key = outer_member_key_value[0];
string& value = outer_member_key_value[1];
@@ -208,6 +208,7 @@ struct ParsedModelFlags {
Arg<bool> dump_graphviz_video = Arg<bool>(false);
Arg<bool> allow_nonexistent_arrays = Arg<bool>(false);
Arg<bool> allow_nonascii_arrays = Arg<bool>(false);
+ Arg<string> arrays_extra_info_file;
};
// Flags that describe the operation you would like to do (what conversion
@@ -232,4 +233,4 @@ struct ParsedTocoFlags {
};
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index 39809216c7..c726eb6d86 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -278,8 +278,8 @@ std::vector<const Operator*> OperatorsToDump(const Model& model) {
if (last_specified) {
// Return only the part of the graph between graphviz_first_array
// and graphviz_last_array.
- CHECK(model.arrays.count(dump_options.graphviz_first_array));
- CHECK(model.arrays.count(dump_options.graphviz_last_array));
+ CHECK(model.HasArray(dump_options.graphviz_first_array));
+ CHECK(model.HasArray(dump_options.graphviz_last_array));
std::unordered_set<string> arrays_already_produced;
std::vector<string> arrays_to_produce;
arrays_to_produce.push_back(dump_options.graphviz_last_array);
@@ -336,7 +336,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
op_properties.color.TextColorString().c_str());
// Add nodes and edges for all inputs of the operator.
for (const auto& input : op.inputs) {
- if (model.arrays.count(input) == 0) {
+ if (!model.HasArray(input)) {
// Arrays should _always_ exist. Except, perhaps, during development.
continue;
}
@@ -352,7 +352,7 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
}
// Add nodes and edges for all outputs of the operator.
for (const auto& output : op.outputs) {
- if (model.arrays.count(output) == 0) {
+ if (!model.HasArray(output)) {
// Arrays should _always_ exist. Except, perhaps, during development.
continue;
}
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.h b/tensorflow/contrib/lite/toco/dump_graphviz.h
index 0fb28e3de8..ea5a4031c3 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.h
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
#include <string>
@@ -25,4 +25,4 @@ void DumpGraphviz(const Model& model, string* output_file_contents);
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 90fa442746..529df3cd2e 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -156,8 +156,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
- CHECK(model.arrays.count(name));
- const auto& input_array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& input_array = model.GetArray(name);
const auto& input_shape = input_array.shape();
CHECK(input_array.buffer);
CHECK(input_array.buffer->type == ArrayDataType::kFloat);
@@ -177,8 +177,8 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
- CHECK(model.arrays.count(name));
- const auto& input_array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& input_array = model.GetArray(name);
const auto& input_shape = input_array.shape();
CHECK(input_array.buffer);
CHECK(input_array.buffer->type == ArrayDataType::kFloat);
@@ -193,8 +193,8 @@ void ConvertIntTensorConst(const Model& model, const string& name,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- CHECK(model.arrays.count(name));
- const auto& array = *model.arrays.at(name);
+ CHECK(model.HasArray(name));
+ const auto& array = model.GetArray(name);
auto* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
@@ -324,7 +324,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
biasadd_op->add_input(conv_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
+ CHECK(model.HasArray(src_op.inputs[2]));
const string& bias_array_name =
WalkUpToConstantArray(model, src_op.inputs[2]);
const auto& bias_array = model.GetArray(bias_array_name);
@@ -361,7 +361,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
// We need to convert that to H x W x InputDepth x Multiplier.
// That's only a matter of constructing a Dims object; the actual
// array layout is the same.
- CHECK(model.arrays.count(src_op.inputs[1]));
+ CHECK(model.HasArray(src_op.inputs[1]));
const string& src_weights_name =
WalkUpToConstantArray(model, src_op.inputs[1]);
const auto& src_weights_array = model.GetArray(src_weights_name);
@@ -404,7 +404,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
biasadd_op->add_input(conv_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
+ CHECK(model.HasArray(src_op.inputs[2]));
const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]);
const auto& bias_array = model.GetArray(bias_name);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
@@ -469,10 +469,10 @@ void ConvertFullyConnectedOperator(const Model& model,
(*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);
(*matmul_op->mutable_attr())["transpose_a"].set_b(false);
(*matmul_op->mutable_attr())["transpose_b"].set_b(false);
- CHECK(model.arrays.count(src_op.inputs[1]));
+ CHECK(model.HasArray(src_op.inputs[1]));
const string& fc_weights_name =
WalkUpToConstantArray(model, src_op.inputs[1]);
- const auto& fc_weights_array = *model.arrays.at(fc_weights_name);
+ const auto& fc_weights_array = model.GetArray(fc_weights_name);
const auto& fc_weights_shape = fc_weights_array.shape();
CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
@@ -492,8 +492,8 @@ void ConvertFullyConnectedOperator(const Model& model,
biasadd_op->add_input(matmul_output);
biasadd_op->add_input(src_op.inputs[2]);
(*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
- CHECK(model.arrays.count(src_op.inputs[2]));
- const auto& bias_array = *model.arrays.at(src_op.inputs[2]);
+ CHECK(model.HasArray(src_op.inputs[2]));
+ const auto& bias_array = model.GetArray(src_op.inputs[2]);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
Shape bias_shape_1d = bias_array.shape();
UnextendShape(&bias_shape_1d, 1);
@@ -519,6 +519,18 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op,
(*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
+void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ auto* add_op = tensorflow_graph->add_node();
+ add_op->set_op("AddN");
+ add_op->set_name(src_op.outputs[0]);
+ for (const auto& input : src_op.inputs) {
+ *add_op->add_input() = input;
+ }
+ (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
+ (*add_op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
void ConvertMulOperator(const Model& model, const MulOperator& src_op,
GraphDef* tensorflow_graph) {
auto* add_op = tensorflow_graph->add_node();
@@ -625,7 +637,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
*reshape_op->add_input() = softmax_size;
(*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
- const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape();
+ const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
int32 flattened_size = 1;
for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
flattened_size *= input_shape.dims(i);
@@ -1013,8 +1025,8 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Op names have been chosen to match the tf.slim LSTM naming
// as closely as possible.
const int axis =
- model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
- ->shape()
+ model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
+ .shape()
.dimensions_count() -
1;
// Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
@@ -1033,9 +1045,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Write weights
const string weights_output = base + "weights";
- CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
+ CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
const auto& weights_array =
- *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
+ model.GetArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
// Convert 4D FullyConnected weights into 2D matrix
const auto& weights_shape = weights_array.shape();
CHECK_EQ(weights_shape.dimensions_count(), 2);
@@ -1059,9 +1071,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Write biases
const string biases_output = base + "biases";
- CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
+ CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
const auto& bias_array =
- *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
+ model.GetArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]);
// TODO(b/62904716) Bias arrays should be 1-D, and used directly.
Shape bias_shape_1d = bias_array.shape();
UnextendShape(&bias_shape_1d, 1);
@@ -1406,6 +1418,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kAdd) {
ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kAddN) {
+ ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kMul) {
ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
tensorflow_graph);
@@ -1557,7 +1572,7 @@ void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
- const auto& state_array = *model.arrays.at(name);
+ const auto& state_array = model.GetArray(name);
if (state_array.has_shape()) {
const auto& state_shape = state_array.shape();
const int kDims = state_shape.dimensions_count();
@@ -1574,7 +1589,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model,
GraphDef* tensorflow_graph) {
for (const auto& input_array : model.flags.input_arrays()) {
AddPlaceholder(input_array.name(),
- model.arrays.at(input_array.name())->data_type,
+ model.GetArray(input_array.name()).data_type,
tensorflow_graph);
}
for (const auto& rnn_state : model.flags.rnn_states()) {
@@ -1588,7 +1603,7 @@ void ExportTensorFlowGraphDefImplementation(const Model& model,
// by the above operators export. It's important that this comes
// after, as some operators need to export arrays that they reference
// in a specific way, rather than in the generic way done below.
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
const string& array_name = array_pair.first;
const auto& array = *array_pair.second;
if (array.buffer) {
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/contrib/lite/toco/export_tensorflow.h
index eca9774576..79682153a8 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
#include <string>
#include "tensorflow/contrib/lite/toco/model.h"
@@ -24,4 +24,4 @@ void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents);
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_
diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/contrib/lite/toco/format_port.h
index 0e999001e0..eb81e90faf 100644
--- a/tensorflow/contrib/lite/toco/format_port.h
+++ b/tensorflow/contrib/lite/toco/format_port.h
@@ -16,8 +16,8 @@ limitations under the License.
// and util::format::AppendF. Unfortunately, type safety is not as good as a
// a full C++ example.
// TODO(aselle): When absl adds support for StrFormat, use that instead.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
#include "tensorflow/contrib/lite/toco/toco_types.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
@@ -74,4 +74,4 @@ inline string StringF(const char* fmt, Args&&... args) {
} // namespace port
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
index 4776741ab9..5e07795223 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -229,7 +229,7 @@ additional information about the multiple input arrays:
well-formed quantized representation of these graphs. Such graphs should be
fixed, but as a temporary work-around, setting this
reorder_across_fake_quant flag allows the converter to perform necessary
- graph transformaitons on them, at the cost of no longer faithfully matching
+ graph transformations on them, at the cost of no longer faithfully matching
inference and training arithmetic.
### Logging flags
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
index 3bde9b0169..56f48d47de 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
@@ -35,7 +35,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(expand_op->inputs.size(), 2);
CHECK_EQ(expand_op->outputs.size(), 1);
- const auto& input_array = *model->arrays[expand_op->inputs[0]];
+ const auto& input_array = model->GetArray(expand_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return false;
@@ -46,7 +46,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
return false;
}
- const auto& axis_array = *model->arrays[expand_op->inputs[1]];
+ const auto& axis_array = model->GetArray(expand_op->inputs[1]);
if (!axis_array.has_shape()) {
// Yield until input axis array shape has been resolved.
return false;
@@ -86,7 +86,7 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
if (IsDiscardableArray(*model, axis_array_name) &&
CountOpsWithInput(*model, axis_array_name) == 1 &&
!GetOpWithOutput(*model, axis_array_name)) {
- model->arrays.erase(axis_array_name);
+ model->EraseArray(axis_array_name);
}
// Replace the operator in the graph.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index bf454c40c7..d38db85280 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -58,7 +58,7 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
depthwiseconv_op->outputs = {conv_op->outputs[0]};
if (conv_op->outputs.size() > 1) {
// delete the im2col array.
- model->arrays.erase(conv_op->outputs[1]);
+ model->EraseArray(conv_op->outputs[1]);
}
depthwiseconv_op->fused_activation_function =
conv_op->fused_activation_function;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc
new file mode 100644
index 0000000000..0d274fc687
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc
@@ -0,0 +1,149 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// Creates a Reshape operator from ReorderAxes operator.
+TensorFlowReshapeOperator* CreateReshapeFromReorderAxes(
+ Model* model, ReorderAxesOperator* reorder_op, const Shape& input_shape) {
+ auto* reshape_op = new TensorFlowReshapeOperator;
+
+ // Copy inputs and outputs to Reshape.
+ reshape_op->inputs.push_back(reorder_op->inputs[0]);
+ reshape_op->outputs = reorder_op->outputs;
+
+ // Create reshape dimensions based on input shape. Conversion from
+ // ReorderAxes to Reshape requires a 4D input shape.
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ std::vector<int> reshape_dims = {1, input_shape.dims(0), input_shape.dims(1),
+ input_shape.dims(3) * input_shape.dims(2)};
+
+ // Create a new input array for Reshape.
+ string reshape_array_name =
+ AvailableArrayName(*model, reshape_op->outputs[0]);
+ reshape_op->inputs.push_back(reshape_array_name);
+
+ Array& reshape_array = model->GetOrCreateArray(reshape_array_name);
+ *(reshape_array.mutable_shape()->mutable_dims()) = {
+ 1, static_cast<int>(reshape_dims.size())};
+ reshape_array.data_type = ArrayDataType::kInt32;
+ auto& reshape_buffer =
+ reshape_array.GetMutableBuffer<ArrayDataType::kInt32>();
+ reshape_buffer.data = reshape_dims;
+
+ return reshape_op;
+}
+
+// Creates a Transpose operator from ReorderAxes operator.
+TransposeOperator* CreateTransposeFromReorderAxes(
+ Model* model, ReorderAxesOperator* reorder_op, const Shape& input_shape,
+ const AxesOrder& input_axes_order, const AxesOrder& output_axes_order) {
+ auto* transpose_op = new TransposeOperator;
+
+ // Copy inputs and outputs to Transpose.
+ transpose_op->inputs.push_back(reorder_op->inputs[0]);
+ transpose_op->outputs = reorder_op->outputs;
+
+ // Create permutations data based on input and output axes order.
+ std::vector<int> permutations_data;
+ GetShuffleShape(input_axes_order, output_axes_order, &permutations_data);
+
+ // Create a new input permutations array for Transpose.
+ string perm_array_name = AvailableArrayName(*model, transpose_op->outputs[0]);
+ transpose_op->inputs.push_back(perm_array_name);
+
+ Array& perm_array = model->GetOrCreateArray(perm_array_name);
+ *(perm_array.mutable_shape()->mutable_dims()) = {
+ static_cast<int>(permutations_data.size())};
+ perm_array.data_type = ArrayDataType::kInt32;
+ auto& perm_buffer = perm_array.GetMutableBuffer<ArrayDataType::kInt32>();
+ perm_buffer.data = permutations_data;
+
+ return transpose_op;
+}
+
+// Converts ReorderAxes into Transpose and Reshape which are compatible with the
+// TFLite interpreter.
+bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
+ auto reorder_it = model->operators.begin() + op_index;
+ if (reorder_it->get()->type != OperatorType::kReorderAxes) return false;
+
+ auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
+ CHECK_EQ(reorder_op->inputs.size(), 1);
+ CHECK_EQ(reorder_op->outputs.size(), 1);
+
+ const auto& input_array_name = reorder_op->inputs[0];
+ const auto& output_array_name = reorder_op->outputs[0];
+ auto& input_array = model->GetArray(input_array_name);
+ auto& output_array = model->GetArray(output_array_name);
+
+ // Get input array. If kFakeQuant is the input into ReorderAxes, get the input
+ // array passed into kFakeQuant. kFakeQuant op is dropped when possible.
+ string constant_input_array_name = input_array_name;
+ if (!input_array.buffer) {
+ const auto* op_producing_input = GetOpWithOutput(*model, input_array_name);
+ if (op_producing_input &&
+ op_producing_input->type == OperatorType::kFakeQuant) {
+ constant_input_array_name = op_producing_input->inputs[0];
+ }
+ }
+
+ // Yield if input array contains constants or if output array size has not
+ // been adjusted to reflect the permutations in ReorderAxes. ReorderAxes will
+ // be merged into a constant array when possible.
+ if (IsConstantParameterArray(*model, constant_input_array_name)) return false;
+ if (!output_array.has_shape()) return false;
+
+ const auto input_axes_order = reorder_op->input_axes_order;
+ const auto output_axes_order = reorder_op->output_axes_order;
+ const Shape input_shape = input_array.shape();
+
+ // Creates a Reshape or Transpose operator depending on the conversion.
+ if (input_axes_order == AxesOrder::kHWIM &&
+ output_axes_order == AxesOrder::k1HWO) {
+ // Add Reshape operator into the graph. This special case is not just a
+ // permutation. The input dimensions get merged into 3 dimensions while the
+ // order of the elements does not change.
+ auto* reshape_op =
+ CreateReshapeFromReorderAxes(model, reorder_op, input_shape);
+ const auto reshape_it = model->operators.emplace(reorder_it, reshape_op);
+ reorder_it = reshape_it + 1;
+ } else {
+ // Add Transpose operator into the graph.
+ auto* transpose_op = CreateTransposeFromReorderAxes(
+ model, reorder_op, input_shape, input_axes_order, output_axes_order);
+ const auto transpose_it =
+ model->operators.emplace(reorder_it, transpose_op);
+ reorder_it = transpose_it + 1;
+ }
+
+ // Remove ReorderAxes operator from the graph.
+ CHECK_EQ(reorder_it->get(), reorder_op);
+ model->operators.erase(reorder_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
new file mode 100644
index 0000000000..dcaaddbf3b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// This pass will convert an AddN operator with only 2 inputs into a regular Add
+// operator, to which more optimizations may apply.
+bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
+ auto addn_it = model->operators.begin() + op_index;
+ if (addn_it->get()->type != OperatorType::kAddN) {
+ return false;
+ }
+ AddNOperator* addn_op = static_cast<AddNOperator*>(addn_it->get());
+ CHECK_GE(addn_op->inputs.size(), 2);
+ CHECK_EQ(addn_op->outputs.size(), 1);
+
+ // We only reduce AddN with N=2 to a regular Add.
+ if (addn_op->inputs.size() != 2) {
+ return false;
+ }
+
+ // Copy inputs & outputs to regular Add.
+ auto* add_op = new AddOperator;
+ add_op->inputs.push_back(addn_op->inputs[0]);
+ add_op->inputs.push_back(addn_op->inputs[1]);
+ add_op->outputs = addn_op->outputs;
+
+ // Replace the AddN operator in the graph.
+ const auto add_it = model->operators.emplace(addn_it, add_op);
+ addn_it = add_it + 1;
+ CHECK_EQ(addn_it->get(), addn_op);
+ model->operators.erase(addn_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
index a234c20924..c2b166033c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
@@ -29,7 +29,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
TransposeOperator* transpose_op =
static_cast<TransposeOperator*>(transpose_it->get());
- const auto& output_array = *model->arrays[transpose_op->outputs[0]];
+ const auto& output_array = model->GetArray(transpose_op->outputs[0]);
if (!output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
return false;
@@ -70,7 +70,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
// Delete perm array if unused
if (IsDiscardableArray(*model, perm_array_name) &&
CountOpsWithInput(*model, perm_array_name) == 1) {
- model->arrays.erase(perm_array_name);
+ model->EraseArray(perm_array_name);
}
// Replace the operator in the graph.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
index 1735b51e5b..076415ece8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -35,7 +35,7 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
// We already have an im2col array
return false;
}
- const auto& weights_array = *model->arrays[conv_op->inputs[1]];
+ const auto& weights_array = model->GetArray(conv_op->inputs[1]);
if (!weights_array.has_shape()) {
// We need to yield until weights dims have been resolved, because
// from the weights dims we determine whether an im2col array is
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
index b89e3f5310..498c864bde 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -53,7 +53,7 @@ std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
}
void ClearArrayQuantizationParams(const string& array_name, Model* model) {
- auto* array = model->arrays.at(array_name).get();
+ auto* array = &model->GetArray(array_name);
CHECK(array->quantization_params);
for (auto& input_array : *model->flags.mutable_input_arrays()) {
if (input_array.name() == array_name) {
@@ -77,7 +77,7 @@ void ClearArrayQuantizationParams(const string& array_name, Model* model) {
bool DequantizeArray(const string& array_name,
GraphTransformation* transformation, Model* model) {
- auto* array = model->arrays.at(array_name).get();
+ auto* array = &model->GetArray(array_name);
if (!array->quantization_params) {
return false;
}
@@ -214,7 +214,9 @@ bool Dequantize::Run(Model* model, std::size_t op_index) {
}
bool changed = false;
for (const string& array : arrays) {
- changed |= DequantizeArray(array, this, model);
+ if (!model->IsOptionalArray(array)) {
+ changed |= DequantizeArray(array, this, model);
+ }
}
return changed;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
index fea360740f..95558ef5ec 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
@@ -45,7 +45,7 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
// Drop min/max inputs
for (int i = 1; i < fakequant_op->inputs.size(); i++) {
if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
- model->arrays.erase(fakequant_op->inputs[i]);
+ model->EraseArray(fakequant_op->inputs[i]);
}
}
fakequant_op->inputs.resize(1);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
index a3ed6663bc..f7fd878b7e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
@@ -32,7 +32,7 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
// Drop the im2col array.
CHECK_EQ(conv_op->outputs.size(), 2);
- model->arrays.erase(conv_op->outputs[1]);
+ model->EraseArray(conv_op->outputs[1]);
conv_op->outputs.resize(1);
AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
index ad4a6f9b78..88e59664ec 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -91,7 +91,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
} else {
LOG(FATAL) << "Unhandled activation function type";
}
- model->arrays.erase(ac_op->inputs[0]);
+ model->EraseArray(ac_op->inputs[0]);
op->outputs[0] = ac_op->outputs[0];
model->operators.erase(ac_it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
index 4619d8bbee..dcbbead517 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -285,13 +285,13 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
LogName(*following_op));
- model->arrays.erase(binary_op->outputs[0]);
+ model->EraseArray(binary_op->outputs[0]);
following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
const auto& old_constant_param_name =
binary_op->inputs[index_of_constant_input];
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
- model->arrays.erase(old_constant_param_name);
+ model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
index 8948653ec3..5b57178b18 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -309,7 +309,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
LOG(FATAL) << "should not get here";
}
- model->arrays.erase(preceding_op->outputs[0]);
+ model->EraseArray(preceding_op->outputs[0]);
preceding_op->outputs[0] = binary_op->outputs[0];
preceding_op->fused_activation_function =
binary_op->fused_activation_function;
@@ -317,7 +317,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
binary_op->inputs[index_of_constant_input];
CHECK(IsConstantParameterArray(*model, old_constant_param_name));
if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
- model->arrays.erase(old_constant_param_name);
+ model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
index f861c4147a..6961e23690 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -31,13 +31,13 @@ namespace {
void PrintModelStats(const string& label, const Model& model) {
int quantized_arrays = 0;
- for (const auto& array : model.arrays) {
+ for (const auto& array : model.GetArrayMap()) {
if (array.second->quantization_params) {
quantized_arrays++;
}
}
LOG(INFO) << label << ": " << model.operators.size() << " operators, "
- << model.arrays.size() << " arrays (" << quantized_arrays
+ << model.GetArrayMap().size() << " arrays (" << quantized_arrays
<< " quantized)";
}
@@ -91,14 +91,9 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
}
} while (found_new_useful_arrays);
// Erase arrays that aren't useful, and that are discardable.
- for (auto it = model->arrays.begin(); it != model->arrays.end();) {
- if (useful_arrays.count(it->first) ||
- !IsDiscardableArray(*model, it->first)) {
- ++it;
- } else {
- it = model->arrays.erase(it);
- }
- }
+ model->EraseArrays([&](const string& name) {
+ return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
+ });
// Erase operators that do not produce a useful output array.
for (auto it = model->operators.begin(); it != model->operators.end();) {
// Only need to test the first output, as we simultaneously added all of
@@ -118,8 +113,8 @@ void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
std::vector<RnnState> rnn_states_to_keep;
for (const auto& rnn_state : model->flags.rnn_states()) {
const bool dangling =
- !model->arrays.count(rnn_state.back_edge_source_array()) ||
- !model->arrays.count(rnn_state.state_array());
+ !model->HasArray(rnn_state.back_edge_source_array()) ||
+ !model->HasArray(rnn_state.state_array());
if (dangling) {
CHECK(rnn_state.discardable());
} else {
@@ -137,6 +132,7 @@ bool GraphTransformationsPass(int increment, Model* model,
CHECK(increment == 1 || increment == -1);
bool changed = false;
if (model->operators.empty()) {
+ LOG(INFO) << "Model is empty!!!";
return false;
}
int op_index = increment == 1 ? 0 : model->operators.size() - 1;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 9ec9f92c90..e11bebcd4e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
#include <cstddef>
#include <initializer_list>
@@ -114,7 +114,9 @@ void RunGraphTransformations(Model* model, const string& message,
// List of all graph transformations
DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
+DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape)
+DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
@@ -192,4 +194,4 @@ class RemoveTrivialReshape : public GraphTransformation {
} // end namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
index 01b75e37c6..419a0776a6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -150,19 +150,19 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
// Erase the subgraph that is now replaced by L2Normalization
model->operators.erase(FindOperator(model, square_op));
- model->arrays.erase(sum_op->inputs[0]);
+ model->EraseArray(sum_op->inputs[0]);
if (sum_op->inputs.size() > 1) {
- model->arrays.erase(sum_op->inputs[1]);
+ model->EraseArray(sum_op->inputs[1]);
}
model->operators.erase(FindOperator(model, sum_op));
if (add_op) {
- model->arrays.erase(add_op->inputs[0]);
- model->arrays.erase(add_op->inputs[1]);
+ model->EraseArray(add_op->inputs[0]);
+ model->EraseArray(add_op->inputs[1]);
model->operators.erase(FindOperator(model, add_op));
}
- model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]);
+ model->EraseArray(sqrt_or_rsqrt_op->inputs[0]);
model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
- model->arrays.erase(div_or_mul_op->inputs[1]);
+ model->EraseArray(div_or_mul_op->inputs[1]);
model->operators.erase(FindOperator(model, div_or_mul_op));
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
index 1865416fc2..e4d52476c6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -92,8 +92,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
// Erase intermediate arrays, keeping input to square op.
- model->arrays.erase(avpool_op->inputs[0]);
- model->arrays.erase(sqrt_op->inputs[0]);
+ model->EraseArray(avpool_op->inputs[0]);
+ model->EraseArray(sqrt_op->inputs[0]);
// Erase three operators being replaced.
model->operators.erase(FindOperator(model, square_op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
index cfc77024e7..d36e950609 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -89,12 +89,12 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
// Erase Maximum scalar input & operator
- model->arrays.erase(maximum_op->inputs[scalar_input_index]);
+ model->EraseArray(maximum_op->inputs[scalar_input_index]);
model->operators.erase(FindOperator(model, maximum_op));
// Erase Minimum inputs & operator
- model->arrays.erase(minimum_op->inputs[0]);
- model->arrays.erase(minimum_op->inputs[1]);
+ model->EraseArray(minimum_op->inputs[0]);
+ model->EraseArray(minimum_op->inputs[1]);
model->operators.erase(FindOperator(model, minimum_op));
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index c6f17cf319..f0d107232b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -27,7 +27,7 @@ namespace {
void SetDataTypeForAllOutputs(Model* model, Operator* op,
ArrayDataType data_type) {
for (const auto& output : op->outputs) {
- model->arrays[output]->data_type = data_type;
+ model->GetArray(output).data_type = data_type;
}
}
} // namespace
@@ -38,7 +38,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// If the data type of some input is unknown, we need to yield.
for (const auto& input : op->inputs) {
- if (model->arrays[input]->data_type == ArrayDataType::kNone) {
+ if (!model->IsOptionalArray(input) &&
+ model->GetArray(input).data_type == ArrayDataType::kNone) {
return false;
}
}
@@ -46,7 +47,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// end if we changed anything, and return the correct boolean value.
std::unordered_map<string, ArrayDataType> old_output_data_types;
for (const auto& output : op->outputs) {
- old_output_data_types[output] = model->arrays[output]->data_type;
+ old_output_data_types[output] = model->GetArray(output).data_type;
}
// Do the actual output data types propagation.
if (op->type == OperatorType::kDequantize ||
@@ -68,18 +69,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
op->type == OperatorType::kFill) {
// These operators produce an output with the same type as their 2nd input
CHECK_GE(op->inputs.size(), 2);
- const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+ const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
} else if (op->type == OperatorType::kCast) {
// Data type of the Cast op is specified.
CHECK_EQ(op->outputs.size(), 1);
auto* cast_op = static_cast<CastOperator*>(op);
- model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
} else if (op->type == OperatorType::kArgMax) {
// Data type of the ArgMax op is specified.
CHECK_EQ(op->outputs.size(), 1);
auto* argmax_op = static_cast<ArgMaxOperator*>(op);
- model->arrays[op->outputs[0]]->data_type = argmax_op->output_data_type;
+ model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
} else if (op->type == OperatorType::kRange) {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
@@ -90,7 +91,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
} else {
// Otherwise use the first input
CHECK_GE(op->inputs.size(), 1);
- data_type = model->arrays[op->inputs[0]]->data_type;
+ data_type = model->GetArray(op->inputs[0]).data_type;
}
CHECK_EQ(op->outputs.size(), 1);
SetDataTypeForAllOutputs(model, op, data_type);
@@ -102,7 +103,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) {
auto output = op->outputs[i];
auto data_type = unsupported_op->output_data_types[i];
- model->arrays[output]->data_type = data_type;
+ model->GetArray(output).data_type = data_type;
}
} else if (op->type == OperatorType::kExpandDims) {
// Yield on ExpandDim until it is converted to Reshape
@@ -110,12 +111,12 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
} else {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
- const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
SetDataTypeForAllOutputs(model, op, data_type);
}
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
- if (old_output_data_types[output] != model->arrays[output]->data_type) {
+ if (old_output_data_types[output] != model->GetArray(output).data_type) {
return true;
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index a939efb4db..4fb3b6ae7a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -85,7 +85,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape1,
int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
const string& weights_name = op.inputs[1];
- const auto& weights_shape = model.arrays.at(weights_name)->shape();
+ const auto& weights_shape = model.GetArray(weights_name).shape();
if (op.type == OperatorType::kConv ||
op.type == OperatorType::kFullyConnected) {
return weights_shape.dims(0);
@@ -98,7 +98,7 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
bool EnsureBiasVectorShape(Model* model, Operator* op) {
const string& weights_name = op->inputs[1];
- const auto& weights_array = *model->arrays[weights_name];
+ const auto& weights_array = model->GetArray(weights_name);
// Yield until weights shape has been resolved.
if (!weights_array.has_shape()) {
return false;
@@ -107,7 +107,7 @@ bool EnsureBiasVectorShape(Model* model, Operator* op) {
if (op->inputs.size() < 3) {
return false;
}
- auto& bias_array = *model->arrays[op->inputs[2]];
+ auto& bias_array = model->GetArray(op->inputs[2]);
if (bias_array.has_shape()) {
return true;
}
@@ -126,7 +126,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -134,7 +134,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
const auto& input_shape = input_array.shape();
CHECK_EQ(input_shape.dimensions_count(), 4);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
@@ -156,7 +156,7 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
if (op->outputs.size() == 2) {
const auto& output_shape = output_array.shape();
const int input_depth = weights_shape.dims(3);
- auto& im2col_array = *model->arrays[op->outputs[1]];
+ auto& im2col_array = model->GetArray(op->outputs[1]);
im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
output_shape.dims(2),
input_depth * kheight * kwidth});
@@ -168,7 +168,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -176,7 +176,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
const auto& input_shape = input_array.shape();
CHECK_EQ(input_shape.dimensions_count(), 4);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
@@ -209,7 +209,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
}
void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -232,7 +232,7 @@ void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
}
void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -258,7 +258,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
void ProcessFillOperator(Model* model, FillOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
@@ -287,7 +287,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -295,7 +295,7 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
const auto& input_shape = input_array.shape();
CHECK_GE(input_shape.dimensions_count(), 1);
- const auto& weights_array = *model->arrays[op->inputs[1]];
+ const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
@@ -315,13 +315,13 @@ void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
void ProcessTensorFlowReshapeOperator(Model* model,
TensorFlowReshapeOperator* op) {
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
@@ -377,14 +377,14 @@ void ProcessTensorFlowReshapeOperator(Model* model,
}
void ProcessSimpleOperator(Model* model, Operator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
}
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
if (output_array.has_shape()) {
return;
}
@@ -394,18 +394,40 @@ void ProcessSimpleOperator(Model* model, Operator* op) {
void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
CHECK_EQ(op->inputs.size(), 2);
- const auto& input0_array = *model->arrays[op->inputs[0]];
- const auto& input1_array = *model->arrays[op->inputs[1]];
+ const auto& input0_array = model->GetArray(op->inputs[0]);
+ const auto& input1_array = model->GetArray(op->inputs[1]);
// Yield until input dims have been resolved.
if (!input0_array.has_shape() || !input1_array.has_shape()) {
return;
}
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
&output_array);
}
+void ProcessAddNOperator(Model* model, Operator* op) {
+ // Yield until all input dims have been resolved.
+ //
+ // TODO(myenik): Since AddN does not support broadcasting, maybe we could
+ // actually use this to improve shape propagation by propagating the shape of
+ // one input to all other inputs once it is resolved instead of just the
+ // output, since all inputs must be the same size and shape for a well-formed
+ // graph.
+ for (const auto& input : op->inputs) {
+ const auto& input_array = model->GetArray(input);
+ if (!input_array.has_shape()) {
+ return;
+ }
+ }
+
+ // AddN does not support broadcasting, all inputs must be the same shape, so
+ // we just take the first input shape and apply it to the output.
+ const auto& input0_array = model->GetArray(op->inputs[0]);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.copy_shape(input0_array.shape());
+}
+
bool KeepDims(const Operator& op) {
switch (op.type) {
case OperatorType::kTensorFlowMin:
@@ -424,11 +446,11 @@ bool KeepDims(const Operator& op) {
void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
CHECK_LE(op->inputs.size(), 2);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
return;
}
@@ -436,7 +458,7 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
const bool keep_dims = KeepDims(*op);
if (op->inputs.size() == 2) {
// There is a reduction_indices input.
- const auto& reduction_array = *model->arrays[op->inputs[1]];
+ const auto& reduction_array = model->GetArray(op->inputs[1]);
if (!reduction_array.buffer) {
return;
}
@@ -476,11 +498,11 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) {
if (op->begin.empty()) return;
// Yield until input dims have been resolved.
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) return;
const Shape& input_shape = input_array.shape();
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
CHECK_EQ(input_shape.dims().size(), op->size.size());
@@ -500,7 +522,7 @@ void ProcessSliceOperator(Model* model, SliceOperator* op) {
void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -515,20 +537,20 @@ void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
// Yield until input dims have been resolved.
for (const auto& input_name : op->inputs) {
- auto& input_array = *model->arrays[input_name];
+ auto& input_array = model->GetArray(input_name);
if (!input_array.has_shape()) {
return;
}
}
auto& output_array = model->GetArray(op->outputs[0]);
// Use 0 input as basis for output dimensions.
- const auto& first_input_array = *model->arrays[op->inputs[0]];
+ const auto& first_input_array = model->GetArray(op->inputs[0]);
output_array.copy_shape(first_input_array.shape());
// Determine the concat size, and enfore that all inputs have
// the same dimensions count.
int concat_size = 0;
for (const auto& input_name : op->inputs) {
- auto& input_array = *model->arrays[input_name];
+ auto& input_array = model->GetArray(input_name);
CHECK(input_array.has_shape());
if (input_array.shape().dimensions_count() == 0) {
continue;
@@ -548,16 +570,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
void ProcessRangeOperator(Model* model, RangeOperator* op) {
CHECK_EQ(op->inputs.size(), 3);
- const auto& start_array = *model->arrays[op->inputs[0]];
+ const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until input dims have been resolved.
return;
}
- const auto& limit_array = *model->arrays[op->inputs[1]];
+ const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
return;
}
- const auto& delta_array = *model->arrays[op->inputs[2]];
+ const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
return;
}
@@ -599,7 +621,7 @@ void ProcessRangeOperator(Model* model, RangeOperator* op) {
void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
const string& input_name = op->inputs[1];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -618,13 +640,13 @@ void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
CHECK_EQ(op->outputs.size(), op->num_split);
for (const auto& output : op->outputs) {
- model->arrays[output]->copy_shape(output_shape);
+ model->GetArray(output).copy_shape(output_shape);
}
}
void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -641,7 +663,7 @@ void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -658,7 +680,7 @@ void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
const string& input_name = op->inputs[0];
- const auto& input_array = *model->arrays[input_name];
+ const auto& input_array = model->GetArray(input_name);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -679,14 +701,14 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- if (!model->arrays[op->inputs[0]]->has_shape() ||
- !model->arrays[op->inputs[1]]->has_shape()) {
+ if (!model->GetArray(op->inputs[0]).has_shape() ||
+ !model->GetArray(op->inputs[1]).has_shape()) {
return;
}
- const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();
+ const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
const string& output_size_name = op->inputs[1];
- const auto& output_size_array = *model->arrays[output_size_name];
+ const auto& output_size_array = model->GetArray(output_size_name);
CHECK(output_size_array.data_type == ArrayDataType::kInt32);
CHECK(output_size_array.has_shape());
const auto& output_size_shape = output_size_array.shape();
@@ -697,9 +719,9 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
}
std::vector<int32> output_shape =
output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
- input_data_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
+ output_shape[1], input_data_shape.dims(3)}));
}
void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
@@ -708,7 +730,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
const auto& input_array =
- *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
// Yield until all input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -717,7 +739,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
CHECK_GE(input_shape.dimensions_count(), 2);
const auto& prev_activ_array =
- *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
// Yield until all input dims have been resolved.
if (!prev_activ_array.has_shape()) {
return;
@@ -726,7 +748,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
CHECK_GE(prev_activ_shape.dimensions_count(), 2);
const auto& weights_array =
- *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
// Yield until weights dims have been resolved.
if (!weights_array.has_shape()) {
return;
@@ -735,7 +757,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
CHECK_EQ(weights_shape.dimensions_count(), 2);
const auto& bias_array =
- *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
// Yield until bias dims have been resolved.
if (!bias_array.has_shape()) {
return;
@@ -744,7 +766,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
CHECK_GE(bias_shape.dimensions_count(), 1);
const auto& prev_state_array =
- *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
+ model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
// Yield until all input dims have been resolved.
if (!prev_state_array.has_shape()) {
return;
@@ -784,7 +806,7 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
}
void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -797,8 +819,8 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
const auto input_height = input_shape.dims(1);
const auto input_width = input_shape.dims(2);
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
- const auto& paddings_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
+ const auto& paddings_array = model->GetArray(op->inputs[2]);
const auto& block_shape_array_shape = block_shape_array.shape();
const auto& paddings_array_shape = paddings_array.shape();
QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
@@ -830,13 +852,13 @@ void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
int output_height = height_with_paddings / block_height;
int output_width = width_with_paddings / block_width;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_shape.dims(0) * block_height * block_width, output_height,
- output_width, input_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
+ output_height, output_width, input_shape.dims(3)}));
}
void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -846,8 +868,8 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
const auto input_height = input_shape.dims(1);
const auto input_width = input_shape.dims(2);
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
- const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
+ const auto& crops_array = model->GetArray(op->inputs[2]);
const auto& block_shape_array_shape = block_shape_array.shape();
const auto& crops_array_shape = crops_array.shape();
QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
@@ -882,15 +904,15 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
int output_height = input_height * block_height;
int output_width = input_width * block_width;
- model->arrays[op->outputs[0]]->copy_shape(
- Shape({input_shape.dims(0) / (block_height * block_width), output_height,
- output_width, input_shape.dims(3)}));
+ model->GetArray(op->outputs[0])
+ .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
+ output_height, output_width, input_shape.dims(3)}));
}
void ProcessGatherOperator(Model* model, GatherOperator* op) {
- const auto& input_array = *model->arrays[op->inputs[0]];
- const auto& indices_array = *model->arrays[op->inputs[1]];
- auto& output_array = *model->arrays[op->outputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ const auto& indices_array = model->GetArray(op->inputs[1]);
+ auto& output_array = model->GetArray(op->outputs[0]);
// Bail if we already know the output shape.
if (output_array.has_shape()) {
@@ -924,7 +946,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) return;
@@ -932,7 +954,7 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
if (op->left_padding.empty()) return;
CHECK_EQ(op->left_padding.size(), op->right_padding.size());
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
Shape output_shape = input_array.shape();
@@ -949,13 +971,13 @@ void ProcessPadOperator(Model* model, PadOperator* op) {
void ProcessRankOperator(Model* model, RankOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
@@ -970,13 +992,13 @@ void ProcessRankOperator(Model* model, RankOperator* op) {
void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
@@ -991,7 +1013,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
void ProcessStackOperator(Model* model, StackOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
@@ -1032,7 +1054,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
@@ -1112,12 +1134,12 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
CHECK_EQ(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) return;
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) return;
const std::vector<int>& input_dims = input_array.shape().dims();
@@ -1136,18 +1158,18 @@ void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) return;
- auto& weights_feature_array = *model->arrays[op->inputs[1]];
+ auto& weights_feature_array = model->GetArray(op->inputs[1]);
if (!weights_feature_array.has_shape()) return;
- const auto& weights_time_array = *model->arrays[op->inputs[2]];
+ const auto& weights_time_array = model->GetArray(op->inputs[2]);
if (!weights_time_array.has_shape()) return;
const bool has_bias = (op->inputs.size() == 4);
if (has_bias) {
- const auto& bias_array = *model->arrays[op->inputs[3]];
+ const auto& bias_array = model->GetArray(op->inputs[3]);
if (!bias_array.has_shape()) return;
}
@@ -1164,13 +1186,13 @@ void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
}
void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// We have already run
return;
}
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
return;
@@ -1204,7 +1226,7 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
CHECK_EQ(op->inputs.size(), 2);
- const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
return;
@@ -1222,7 +1244,7 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
}
output_dims.push_back(1);
const string& output_name = op->outputs[0];
- auto& output_array = *model->arrays[output_name];
+ auto& output_array = model->GetArray(output_name);
if (output_array.has_shape()) {
return;
}
@@ -1236,8 +1258,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
auto* op = it->get();
std::unordered_map<string, std::vector<int>> old_output_dims;
for (const auto& output : op->outputs) {
- if (model->arrays[output]->has_shape()) {
- old_output_dims[output] = model->arrays[output]->shape().dims();
+ if (model->GetArray(output).has_shape()) {
+ old_output_dims[output] = model->GetArray(output).shape().dims();
}
}
@@ -1282,6 +1304,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTensorFlowGreaterEqual:
ProcessSimpleBinaryOperator(model, op);
break;
+ case OperatorType::kAddN:
+ ProcessAddNOperator(model, op);
+ break;
case OperatorType::kConv:
ProcessConvOperator(model, static_cast<ConvOperator*>(op));
break;
@@ -1433,10 +1458,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
// Return true if any output dim changed, false if none changed.
// Assumption: no transformation clears an output shape, they only add shapes.
for (const auto& output : op->outputs) {
- if (model->arrays[output]->has_shape() &&
- (old_output_dims[output] != model->arrays[output]->shape().dims())) {
+ if (model->GetArray(output).has_shape() &&
+ (old_output_dims[output] != model->GetArray(output).shape().dims())) {
AddMessageF("Set shape of %s to [%s]", output,
- absl::StrJoin(model->arrays[output]->shape().dims(), ","));
+ absl::StrJoin(model->GetArray(output).shape().dims(), ","));
return true;
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 56082b965a..b973b2b813 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -412,7 +412,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
}
}
- model->arrays.erase(dequantize_op->outputs[0]);
+ model->EraseArray(dequantize_op->outputs[0]);
model->operators.erase(dequantize_it);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
index 371ced388a..11f8d4b6ee 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
@@ -80,7 +80,7 @@ bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
// else.
for (int i = 1; i <= 2; i++) {
if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
- model->arrays.erase(fakequant_op->inputs[i]);
+ model->EraseArray(fakequant_op->inputs[i]);
}
}
fakequant_op->inputs.resize(1);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
index 3992e7d1ef..c3b2709a33 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
@@ -51,7 +51,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
// Remove the node and its output array.
AddMessageF("Removed final %s", LogName(*dequantize_op));
- model->arrays.erase(output);
+ model->EraseArray(output);
model->operators.erase(dequantize_it);
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
index 6add443f2d..95a50c6179 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
@@ -81,7 +81,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
// Now check if the constant operand makes this binary
// operator trivial.
const auto& constant_input_array =
- *model->arrays[binary_op->inputs[index_of_constant_input]];
+ model->GetArray(binary_op->inputs[index_of_constant_input]);
// For now, we only handle floats here.
if (constant_input_array.data_type != ArrayDataType::kFloat) {
return false;
@@ -89,14 +89,14 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& constant_input_float_data =
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
bool is_trivial = false;
- if (binary_op->type != OperatorType::kAdd) {
+ if (binary_op->type == OperatorType::kAdd) {
is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
- } else if (binary_op->type != OperatorType::kSub) {
+ } else if (binary_op->type == OperatorType::kSub) {
is_trivial = index_of_constant_input == 1 &&
AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
- } else if (binary_op->type != OperatorType::kMul) {
+ } else if (binary_op->type == OperatorType::kMul) {
is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
- } else if (binary_op->type != OperatorType::kDiv) {
+ } else if (binary_op->type == OperatorType::kDiv) {
is_trivial = index_of_constant_input == 1 &&
AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
index 23a5c857e8..936854a04f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
@@ -59,7 +59,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
for (const string& input : trivial_inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
concat_op->inputs = nontrivial_inputs;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index 047389f69a..587f171bbf 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -124,7 +124,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
}
}
if (!is_referenced) {
- model->arrays.erase(removal_candidate);
+ model->EraseArray(removal_candidate);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
index a06181ca0b..9d448c3ee9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
@@ -54,4 +54,4 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
index e6cca8acf3..aa2c293382 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -33,7 +33,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
// the model. We allow specifying an arbitrary input_array,
// treating the part of the graph leading up to it as unused.
for (const auto& output : op->outputs) {
- CHECK(model->arrays.count(output));
+ CHECK(model->HasArray(output));
// If this output is provided as the model's input array,
// then we don't need this operator to produce its contents.
if (IsInputArray(*model, output)) {
@@ -93,7 +93,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1 &&
!GetOpWithOutput(*model, input)) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
@@ -116,7 +116,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
continue;
}
// Generic case: do delete this output array.
- model->arrays.erase(output);
+ model->EraseArray(output);
}
model->operators.erase(it);
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
index 3eb7fa3896..fb109eb91b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
@@ -121,9 +121,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
}
// Remove the old param arrays
- model->arrays.erase(bn_op->inputs[1]);
- model->arrays.erase(bn_op->inputs[2]);
- model->arrays.erase(bn_op->inputs[3]);
+ model->EraseArray(bn_op->inputs[1]);
+ model->EraseArray(bn_op->inputs[2]);
+ model->EraseArray(bn_op->inputs[3]);
// Remove the old operator
DCHECK_EQ(bn_it->get(), bn_op);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
index 7777d4f543..a06919e228 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
@@ -42,7 +42,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
return false;
// Handle crops
- const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& crops_array = model->GetArray(op->inputs[2]);
if (!crops_array.has_shape()) return false;
const std::vector<int>& crops_dims = crops_array.shape().dims();
if (crops_dims.size() != 2) {
@@ -58,7 +58,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
}
// Handle block_shape
- const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& block_shape_array = model->GetArray(op->inputs[1]);
if (!block_shape_array.has_shape()) return false;
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index fd51df4058..5e779f6765 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -166,8 +166,9 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
const Operator* binary_op) {
- const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
- const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type;
+ const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
+ const auto output_data_type =
+ model->GetArray(binary_op->outputs[0]).data_type;
#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \
if (inputs_data_type == InputsDataType && \
output_data_type == OutputDataType) { \
@@ -214,7 +215,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
return false;
}
- auto& output_array = *model->arrays[binary_op->outputs[0]];
+ auto& output_array = model->GetArray(binary_op->outputs[0]);
// Yield until the output array dims have been resolved.
if (!output_array.has_shape()) {
return false;
@@ -239,10 +240,10 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
// Remove the binary operator and its inputs
if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
- model->arrays.erase(binary_op->inputs[0]);
+ model->EraseArray(binary_op->inputs[0]);
}
if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
- model->arrays.erase(binary_op->inputs[1]);
+ model->EraseArray(binary_op->inputs[1]);
}
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*binary_op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
index 9835f86398..5ac449749a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -189,7 +189,10 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// Remove all the resolved arrays.
for (const string& input_name : concat_op->inputs) {
- model->arrays.erase(input_name);
+ // Check to prevent removal of shared tensors
+ if(CountOpsWithInput(*model, input_name) == 1) {
+ model->EraseArray(input_name);
+ }
}
// Remove concatenate operator
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index 244adcc4c4..81fe37d7e0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -66,7 +66,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
output_buffer.data[i] = dst_val;
}
if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
- model->arrays.erase(fakequant_op->inputs[0]);
+ model->EraseArray(fakequant_op->inputs[0]);
}
model->operators.erase(fakequant_it);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
index 9da51d9147..f6f95481b5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
@@ -104,11 +104,11 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
// Erase input arrays if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
if (IsDiscardableArray(*model, op->inputs[1]) &&
CountOpsWithInput(*model, op->inputs[1]) == 1) {
- model->arrays.erase(op->inputs[1]);
+ model->EraseArray(op->inputs[1]);
}
// Erase the operator
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
index 383d54aa5a..1a0ba9e2bc 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
@@ -28,17 +28,17 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
auto* op = static_cast<RangeOperator*>(base_op);
CHECK_EQ(op->inputs.size(), 3);
- const auto& start_array = *model->arrays[op->inputs[0]];
+ const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
}
- const auto& limit_array = *model->arrays[op->inputs[1]];
+ const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
}
- const auto& delta_array = *model->arrays[op->inputs[2]];
+ const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
// Yield until all input dims have been resolved.
return false;
@@ -52,7 +52,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
}
CHECK_EQ(op->outputs.size(), 1);
- auto& output_array = *model->arrays[op->outputs[0]];
+ auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
@@ -87,15 +87,15 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
// Delete the input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
if (IsDiscardableArray(*model, op->inputs[1]) &&
CountOpsWithInput(*model, op->inputs[1]) == 1) {
- model->arrays.erase(op->inputs[1]);
+ model->EraseArray(op->inputs[1]);
}
if (IsDiscardableArray(*model, op->inputs[2]) &&
CountOpsWithInput(*model, op->inputs[2]) == 1) {
- model->arrays.erase(op->inputs[2]);
+ model->EraseArray(op->inputs[2]);
}
// Delete the operator
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
index 35b81dd550..9ea01acd05 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
@@ -62,7 +62,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
// Delete the input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
model->operators.erase(it);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
index 86c76141a4..ea0d6dc820 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
@@ -101,7 +101,7 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) {
for (const auto& input : op->inputs) {
if (IsDiscardableArray(*model, input) &&
CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 3976d9cbb4..a0cfc3d597 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -186,7 +186,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
// Erase input array if no longer used
if (IsDiscardableArray(*model, op->inputs[0]) &&
CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->arrays.erase(op->inputs[0]);
+ model->EraseArray(op->inputs[0]);
}
// Erase the operator
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index 26ff9d887b..1cd2aff28c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -199,7 +199,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
for (const auto& input : unary_op->inputs) {
if (CountOpsWithInput(*model, input) == 1) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
AddMessageF("Resolved constant %s to the equivalent constant array",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
index b77be3f5c0..013b50ac9b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
@@ -36,7 +36,7 @@ bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) {
if (op->inputs.size() != 2) return false;
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- const auto& indices_array = *model->arrays[op->inputs[1]];
+ const auto& indices_array = model->GetArray(op->inputs[1]);
if (!indices_array.has_shape()) return false;
op->axis = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
index d5f5869c62..8a8e723cf7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
@@ -35,7 +35,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 2);
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- const auto& array = *model->arrays[op->inputs[1]];
+ const auto& array = model->GetArray(op->inputs[1]);
if (!array.has_shape()) return false;
const std::vector<int>& dims = array.shape().dims();
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index b5093bc4c7..5c68f87f6c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -103,7 +103,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
AddMessageF("Reordered axes for array %s", input_array_name);
// Remove the op and output array.
- model->arrays.erase(output_array_name);
+ model->EraseArray(output_array_name);
model->operators.erase(reorder_it);
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
index bed2a85bd2..2e063e3554 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
@@ -37,7 +37,7 @@ bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
if (!op->shape.empty()) return false;
if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
- const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]];
+ const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]);
op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
index 1d0a2ec8f6..e760d08e5a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
@@ -36,10 +36,10 @@ bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
- const auto& begin_array = *model->arrays[op->inputs[1]];
+ const auto& begin_array = model->GetArray(op->inputs[1]);
if (!begin_array.has_shape()) return false;
- const auto& size_array = *model->arrays[op->inputs[2]];
+ const auto& size_array = model->GetArray(op->inputs[2]);
if (!size_array.has_shape()) return false;
op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
index a73f16735c..dad6aceccf 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
@@ -45,7 +45,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
return false;
// Handle paddings.
- const auto& paddings_array = *model->arrays[op->inputs[paddings_index]];
+ const auto& paddings_array = model->GetArray(op->inputs[paddings_index]);
if (!paddings_array.has_shape()) return false;
const std::vector<int>& paddings_dims = paddings_array.shape().dims();
if (paddings_dims.size() != 2) {
@@ -61,7 +61,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
}
// Handle block_shape.
- const auto& block_shape_array = *model->arrays[op->inputs[block_shape_index]];
+ const auto& block_shape_array =
+ model->GetArray(op->inputs[block_shape_index]);
if (!block_shape_array.has_shape()) return false;
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
index dbe69adcbd..7e8b249b07 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -31,13 +31,17 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
}
CHECK_EQ(op->inputs.size(), 4);
- const auto& start_array = *model->arrays[op->inputs[1]];
+ const auto& start_array = model->GetArray(op->inputs[1]);
if (!start_array.has_shape()) return false;
+ if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
+ // Only 1-4D arrays are supported for now.
+ return false;
+ }
- const auto& stop_array = *model->arrays[op->inputs[2]];
+ const auto& stop_array = model->GetArray(op->inputs[2]);
if (!stop_array.has_shape()) return false;
- const auto& stride_array = *model->arrays[op->inputs[3]];
+ const auto& stride_array = model->GetArray(op->inputs[3]);
if (!stride_array.has_shape()) return false;
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
index c6723a880e..5c0c1e3478 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
@@ -75,7 +75,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
// Remove the axis array if it is not used by anything else.
if (CountOpsWithInput(*model, axis_name) == 1) {
- model->arrays.erase(axis_name);
+ model->EraseArray(axis_name);
}
// Remove the TensorFlowConcat op
model->operators.erase(concat_it);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index bea7487051..ad1e56888e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -69,7 +69,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
LogName(*matmul_op), LogName(*fc_op));
const auto& previous_op_output = previous_op->outputs[0];
if (CountOpsWithInput(*model, previous_op_output) == 1) {
- model->arrays.erase(previous_op_output);
+ model->EraseArray(previous_op_output);
}
CHECK_EQ(previous_op->inputs.size(), 2);
fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]};
@@ -78,7 +78,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
const auto& previous_op_shape = previous_op->inputs[1];
if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
!GetOpWithOutput(*model, previous_op_shape)) {
- model->arrays.erase(previous_op_shape);
+ model->EraseArray(previous_op_shape);
}
model->operators.erase(previous_op_it);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
index cfa5ce0716..477e7f13da 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
@@ -55,7 +55,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
// Remove the node and its output array.
AddMessageF("Removing already-resolved %s", LogName(*merge_op));
- model->arrays.erase(merge_op->outputs[0]);
+ model->EraseArray(merge_op->outputs[0]);
model->operators.erase(merge_it);
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
index 150cf53da3..a418073441 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -103,7 +103,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
// Remove the output arrays if they are now unused.
for (int i = 0; i < 2; i++) {
if (!GetOpWithInput(*model, switch_op->outputs[i])) {
- model->arrays.erase(switch_op->outputs[i]);
+ model->EraseArray(switch_op->outputs[i]);
}
}
// Remove input arrays if they are only used by the switch itself and aren't
@@ -111,7 +111,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
for (const auto& input : switch_op->inputs) {
if (CountOpsWithInput(*model, input) == 1 &&
!GetOpWithOutput(*model, input)) {
- model->arrays.erase(input);
+ model->EraseArray(input);
}
}
// Remove the switch node itself.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
index 9f7e7c42a2..1ddf54c778 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
@@ -45,10 +45,10 @@ void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op,
model->operators.erase(tile_it);
if (!CountOpsWithInput(*model, tile_multiplier_array) &&
!GetOpWithOutput(*model, tile_multiplier_array)) {
- model->arrays.erase(tile_multiplier_array);
+ model->EraseArray(tile_multiplier_array);
}
if (!CountOpsWithInput(*model, tile_output_array)) {
- model->arrays.erase(tile_output_array);
+ model->EraseArray(tile_output_array);
}
}
} // namespace
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
index 12d966b261..a657ee00af 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
@@ -35,7 +35,7 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
// Handling perm.
- const auto& perm_array = *model->arrays[op->inputs[1]];
+ const auto& perm_array = model->GetArray(op->inputs[1]);
if (!perm_array.has_shape()) return false;
const std::vector<int>& perm_dims = perm_array.shape().dims();
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
index a14016e8e2..3a1d175b98 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
@@ -19,7 +19,6 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-//#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
@@ -168,11 +167,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
@@ -187,11 +186,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22.,
@@ -206,11 +205,11 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
- EXPECT_THAT(model.arrays.size(), 5);
+ EXPECT_THAT(model.GetArrayMap().size(), 5);
(*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
- EXPECT_THAT(model.arrays.size(), 1);
+ EXPECT_THAT(model.GetArrayMap().size(), 1);
- auto& concatenated_array = (*model.arrays.begin()).second;
+ auto& concatenated_array = (*model.GetArrayMap().begin()).second;
EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
ElementsAreArray(ArrayFloatNear(
{0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12.,
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
index 4e273343df..2c7046c8c7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
@@ -63,7 +63,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
ac_op->outputs = op->outputs;
const string& tmp_array_name =
AvailableArrayName(*model, op->outputs[0] + "_unfused");
- CHECK(!model->arrays.count(tmp_array_name));
+ CHECK(!model->HasArray(tmp_array_name));
model->GetOrCreateArray(tmp_array_name);
ac_op->inputs = {tmp_array_name};
op->outputs = {tmp_array_name};
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 995e9d67ca..ca378af4c5 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -696,6 +696,19 @@ void ConvertAddOperator(const NodeDef& node,
model->operators.emplace_back(op);
}
+void ConvertAddNOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "AddN");
+ const int num_inputs = GetInputsCount(node, tf_import_flags);
+ auto* op = new AddNOperator;
+ for (int i = 0; i < num_inputs; ++i) {
+ op->inputs.push_back(node.input(i));
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op);
+}
+
void ConvertMulOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1179,6 +1192,8 @@ void ConvertStridedSliceOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "StridedSlice");
+ // TODO(soroosh): The 4th input (strides) should be e optional, to be
+ // consistent with TF.
CheckInputsCount(node, tf_import_flags, 4);
auto* op = new StridedSliceOperator;
@@ -1652,7 +1667,7 @@ void StripCaretFromArrayNames(Model* model) {
output = string(absl::StripPrefix(output, "^"));
}
}
- for (auto& array : model->arrays) {
+ for (auto& array : model->GetArrayMap()) {
if (absl::StartsWith(array.first, "^")) {
LOG(FATAL) << "What?";
}
@@ -1860,6 +1875,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
ConvertSquareOperator(node, tf_import_flags, model);
} else if (node.op() == "Add") {
ConvertAddOperator(node, tf_import_flags, model);
+ } else if (node.op() == "AddN") {
+ ConvertAddNOperator(node, tf_import_flags, model);
} else if (node.op() == "Mul") {
ConvertMulOperator(node, tf_import_flags, model);
} else if (node.op() == "Sub") {
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
index 312e3b8f17..2177872334 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
#include <memory>
#include <string>
@@ -39,4 +39,4 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 7b2235e275..d1af371fd4 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
#include <initializer_list>
#include <memory>
@@ -32,6 +32,7 @@ enum class OperatorType {
kNone,
// General-purpose neural network operators.
kAdd,
+ kAddN,
kAveragePool,
kBatchNormalization,
kConv,
@@ -559,6 +560,16 @@ struct AddOperator : Operator {
AddOperator() : Operator(OperatorType::kAdd) {}
};
+// Element-wise addition operator for N inputs.
+//
+// Inputs:
+// inputs[i]: The i-th array to add together to form the output.
+//
+// TensorFlow equivalent: AddN
+struct AddNOperator : Operator {
+ AddNOperator() : Operator(OperatorType::kAddN) {}
+};
+
// Concatenation operator: concatenates its inputs
// along the axis.
//
@@ -738,6 +749,9 @@ struct PadOperator : Operator {
//
// Inputs:
// inputs[0]: required: the input array
+// inputs[1]: required: the begin array
+// inputs[2]: required: the end array
+// inputs[3]: optional: the strides array
//
// TensorFlow equivalent: StridedSlice
struct StridedSliceOperator : Operator {
@@ -1521,29 +1535,54 @@ struct Array {
// Our Model struct, represents an entire model (our "top-level" struct).
// Owns everything.
-struct Model {
+class Model {
+ public:
+ using ArrayMap = std::unordered_map<string, std::unique_ptr<Array>>;
+
+ bool HasArray(const string& name) const { return arrays.count(name) > 0; }
Array& GetArray(const string& name) const {
- DCHECK(arrays.count(name));
+ DCHECK(HasArray(name));
return *arrays.at(name);
}
Array& GetOrCreateArray(const string& name) {
- if (!arrays.count(name)) {
+ // Make sure name is not used by an optional array
+ DCHECK(!optional_arrays.count(name));
+ if (!HasArray(name)) {
Array* ptr = new Array;
arrays[name] = std::unique_ptr<Array>(ptr);
}
Array& result = GetArray(name);
return result;
}
+ void CreateOptionalArray(const string& name) {
+ DCHECK(!arrays.count(name) && !optional_arrays.count(name));
+ optional_arrays.insert(name);
+ }
+ bool IsOptionalArray(const string& name) const {
+ return optional_arrays.count(name);
+ }
+
+ // Note that this invalidates all array iterators.
+ void EraseArray(const string& name) { arrays.erase(name); }
+ void EraseArrays(std::function<bool(const string&)> discardable) {
+ for (auto it = arrays.begin(); it != arrays.end();) {
+ if (discardable(it->first)) {
+ it = arrays.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+ const ArrayMap& GetArrayMap() const { return arrays; }
+
+ // Optional arrays are used for optional tensors,
+ // these tensors do not have data, but with reserved names as op inputs.
+ std::set<string> optional_arrays;
// The list of operators. Notice how it's a list of unique_ptr's, implying
// that the Model is what owns Operator's and keeps them alive.
std::vector<std::unique_ptr<Operator>> operators;
- // The associative array mapping names to Array's.
- // Notice how it's a container of unique_ptr's, implying
- // that the Model is what owns Array's and keeps them alive.
- // The Operator's refer to these Array's by their name strings, not by their
- // addresses. See Operator::inputs, Operator::outputs.
- std::unordered_map<string, std::unique_ptr<Array>> arrays;
+
// Generic flags, a place where we combine information passed to us via
// command-line parameters (e.g. --input_width=N) with information that
// we may or may not find in the input model file.
@@ -1552,7 +1591,15 @@ struct Model {
std::size_t transient_data_size = 0;
// For code-generation only: required alignment of the transient_data buffer
std::size_t transient_data_alignment = 0;
+
+ private:
+ // The associative array mapping names to Array's.
+ // Notice how it's a container of unique_ptr's, implying
+ // that the Model is what owns Array's and keeps them alive.
+ // The Operator's refer to these Array's by their name strings, not by their
+ // addresses. See Operator::inputs, Operator::outputs.
+ std::unordered_map<string, std::unique_ptr<Array>> arrays;
};
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 790b3443ce..4e2dec15a5 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -148,6 +148,12 @@ bool ParseModelFlagsFromCommandLineFlags(
"ranging from 32 to 127. This is disallowed by default so as to "
"catch common copy-and-paste issues where invisible unicode "
"characters are unwittingly added to these strings."),
+ Flag(
+ "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
+ parsed_flags.arrays_extra_info_file.default_value(),
+ "Path to an optional file containing a serialized ArraysExtraInfo "
+ "proto allowing to pass extra information about arrays not specified "
+ "in the input model file, such as extra MinMax information."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -327,9 +333,6 @@ void ReadModelFlagsFromCommandLineFlags(
CHECK(absl::SimpleAtoi(value, &size));
CHECK_GT(size, 0);
rnn_state_proto->set_size(size);
- } else if (key == "manually_create") {
- CHECK_EQ(absl::AsciiStrToLower(value), "true");
- rnn_state_proto->set_manually_create(true);
} else {
LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
}
@@ -368,6 +371,15 @@ void ReadModelFlagsFromCommandLineFlags(
parsed_model_flags.allow_nonascii_arrays.value());
model_flags->set_allow_nonexistent_arrays(
parsed_model_flags.allow_nonexistent_arrays.value());
+
+ if (parsed_model_flags.arrays_extra_info_file.specified()) {
+ string arrays_extra_info_file_contents;
+ port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(),
+ &arrays_extra_info_file_contents,
+ port::file::Defaults());
+ ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
+ model_flags->mutable_arrays_extra_info());
+ }
}
ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.h b/tensorflow/contrib/lite/toco/model_cmdline_flags.h
index 027d7ae1aa..c868d5c7d0 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.h
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
#include <string>
#include <unordered_map>
@@ -40,5 +40,4 @@ ParsedModelFlags* GlobalParsedModelFlags();
} // namespace toco
-
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
index 13fea29a07..e4b39b34e8 100644
--- a/tensorflow/contrib/lite/toco/model_flags.proto
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -81,19 +81,26 @@ message RnnState {
optional string state_array = 1;
optional string back_edge_source_array = 2;
optional bool discardable = 5;
- // TODO(benoitjacob): drop the 'size' field. Should be redundant with
- // --input_shapes and shapes propagation.
+ // size allows to specify a 1-D shape for the RNN state array.
+ // Will be expanded with 1's to fit the model.
+ // TODO(benoitjacob): should allow a generic, explicit shape.
optional int32 size = 3;
- // TODO(benoitjacob): manually_create is a temporary hack:
- // due to discrepancies between the current toco dims tracking and
- // TensorFlow shapes, for some models we need to manually create RNN state
- // arrays with a specified shape.
- // Maybe we should actually implement back-edges as operators of their own,
- // which would remove the need for much special-casing, including here,
- // we could probably consistently let PropagateFixedSizes handle state
- // arrays.
- // TODO(benoitjacob): should really drop manually_create now.
- optional bool manually_create = 4;
+}
+
+// An ArraysExtraInfo message stores a collection of additional Information
+// about arrays in a model, complementing the information in the model itself.
+// It is intentionally a separate message so that it may be serialized and
+// passed separately from the model. See --arrays_extra_info_file.
+//
+// A typical use case is to manually specify MinMax for specific arrays in a
+// model that does not already contain such MinMax information.
+message ArraysExtraInfo {
+ message Entry {
+ optional string name = 1;
+ optional float min = 2;
+ optional float max = 3;
+ }
+ repeated Entry entries = 1;
}
// ModelFlags encodes properties of a model that, depending on the file
@@ -117,7 +124,7 @@ message RnnState {
// optional int32 input_dims = 11 [ default = 4];
// repeated int32 input_shape = 13;
//
-// Next ID to USE: 18.
+// Next ID to USE: 19.
message ModelFlags {
// Information about the input arrays, i.e. the arrays from which input
// activations will be read.
@@ -160,4 +167,8 @@ message ModelFlags {
// catch common copy-and-paste issues where invisible unicode
// characters are unwittingly added to these strings.
optional bool allow_nonascii_arrays = 17;
+
+ // If set, this ArraysExtraInfo allows to pass extra information about arrays
+ // not specified in the input model file, such as extra MinMax information.
+ optional ArraysExtraInfo arrays_extra_info = 18;
}
diff --git a/tensorflow/contrib/lite/toco/runtime/common.h b/tensorflow/contrib/lite/toco/runtime/common.h
index bd55544f57..3c6828840c 100644
--- a/tensorflow/contrib/lite/toco/runtime/common.h
+++ b/tensorflow/contrib/lite/toco/runtime/common.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
@@ -23,4 +23,4 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/common.h"
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_
diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h
index df63b2d59e..f5de5a5781 100644
--- a/tensorflow/contrib/lite/toco/runtime/types.h
+++ b/tensorflow/contrib/lite/toco/runtime/types.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
@@ -29,4 +29,4 @@ using tflite::RequiredBufferSizeForDims;
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.h b/tensorflow/contrib/lite/toco/tensorflow_util.h
index 152b4f7a72..61f9104268 100644
--- a/tensorflow/contrib/lite/toco/tensorflow_util.h
+++ b/tensorflow/contrib/lite/toco/tensorflow_util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
#include <string>
#include <vector>
@@ -29,4 +29,4 @@ void LogDumpGraphDef(int log_level, const string& message,
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index 332253a092..72c9266564 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -27,7 +27,7 @@ cc_library(
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
- "@flatbuffers//:flatbuffers",
+ "@flatbuffers",
],
)
@@ -41,7 +41,7 @@ tf_cc_test(
"//tensorflow/contrib/lite/toco:tooling_util",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
- "@flatbuffers//:flatbuffers",
+ "@flatbuffers",
],
)
@@ -87,7 +87,7 @@ cc_library(
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/contrib/lite/toco:tooling_util",
"@com_google_absl//absl/strings",
- "@flatbuffers//:flatbuffers",
+ "@flatbuffers",
],
)
@@ -117,7 +117,7 @@ cc_library(
":types",
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:model",
- "@flatbuffers//:flatbuffers",
+ "@flatbuffers",
],
)
@@ -131,7 +131,7 @@ tf_cc_test(
"//tensorflow/contrib/lite:schema_fbs_version",
"//tensorflow/contrib/lite/schema:schema_fbs",
"@com_google_googletest//:gtest_main",
- "@flatbuffers//:flatbuffers",
+ "@flatbuffers",
],
)
diff --git a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h
index 93cc79ddb6..cfe7ecd9f9 100644
--- a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
@@ -71,4 +71,4 @@ class BuiltinOperator : public BaseOperator {
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/custom_operator.h b/tensorflow/contrib/lite/toco/tflite/custom_operator.h
index 1a4bfac7d4..bd5713618f 100644
--- a/tensorflow/contrib/lite/toco/tflite/custom_operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/custom_operator.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
#include "flatbuffers/flexbuffers.h"
#include "absl/memory/memory.h"
@@ -71,4 +71,4 @@ class CustomOperator : public BaseOperator {
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index bec694a233..391ef87029 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -62,7 +62,7 @@ namespace details {
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
// First find a list of unique array names.
std::set<string> names;
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
names.insert(array_pair.first);
}
@@ -96,7 +96,7 @@ Offset<Vector<Offset<Tensor>>> ExportTensors(
// tensors in the tensors_map.
std::map<int, Offset<Tensor>> ordered_tensors;
- for (const auto& array_pair : model.arrays) {
+ for (const auto& array_pair : model.GetArrayMap()) {
const string& tensor_name = array_pair.first;
const toco::Array& array = *array_pair.second;
@@ -235,9 +235,10 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
for (const auto& op : model.operators) {
std::vector<int32_t> inputs;
for (const string& input : op->inputs) {
- inputs.push_back(tensors_map.at(input));
+ // -1 is the ID for optional tensor in TFLite output
+ int id = model.IsOptionalArray(input) ? -1 : tensors_map.at(input);
+ inputs.push_back(id);
}
-
std::vector<int32_t> outputs;
for (const string& output : op->outputs) {
outputs.push_back(tensors_map.at(output));
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 44012b7126..8c79cb8200 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
#include "tensorflow/contrib/lite/toco/model.h"
@@ -73,4 +73,4 @@ void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map);
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/import.h b/tensorflow/contrib/lite/toco/tflite/import.h
index 3c27a2843c..280677bae1 100644
--- a/tensorflow/contrib/lite/toco/tflite/import.h
+++ b/tensorflow/contrib/lite/toco/tflite/import.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/model.h"
@@ -46,4 +46,4 @@ void LoadOperatorsTable(const ::tflite::Model &input_model,
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc
index 309fa6d7f6..aad6e780d5 100644
--- a/tensorflow/contrib/lite/toco/tflite/import_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc
@@ -114,7 +114,7 @@ TEST_F(ImportTest, Tensors) {
auto model = Import(ModelFlags(), InputModelAsString());
- ASSERT_GT(model->arrays.count("tensor_one"), 0);
+ ASSERT_GT(model->HasArray("tensor_one"), 0);
Array& a1 = model->GetArray("tensor_one");
EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 0111e1ed92..298f49025f 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -474,19 +474,11 @@ class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
- auto before_padding = builder->CreateVector(op.left_padding);
- auto after_padding = builder->CreateVector(op.right_padding);
- return ::tflite::CreatePadOptions(*builder, before_padding, after_padding);
+ return ::tflite::CreatePadOptions(*builder);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {
- op->left_padding.insert(op->left_padding.end(),
- options.before_padding()->begin(),
- options.before_padding()->end());
- op->right_padding.insert(op->right_padding.end(),
- options.after_padding()->begin(),
- options.after_padding()->end());
}
};
@@ -617,6 +609,30 @@ class Split : public CustomOperator<TensorFlowSplitOperator> {
}
};
+class StridedSlice
+ : public BuiltinOperator<StridedSliceOperator,
+ ::tflite::StridedSliceOptions,
+ ::tflite::BuiltinOptions_StridedSliceOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateStridedSliceOptions(
+ *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
+ op.new_axis_mask, op.shrink_axis_mask);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->begin_mask = options.begin_mask();
+ op->end_mask = options.end_mask();
+ op->ellipsis_mask = options.ellipsis_mask();
+ op->new_axis_mask = options.new_axis_mask();
+ op->shrink_axis_mask = options.shrink_axis_mask();
+ }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -777,6 +793,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
ops.emplace_back(
new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
+ ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
+ OperatorType::kStridedSlice));
// Custom Operators.
ops.emplace_back(new Cast("CAST", OperatorType::kCast));
@@ -789,6 +807,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
// There operators are supported by Toco, but not by TF Lite, and has no
// attributes.
+ ops.emplace_back(
+ new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN));
ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
"RSQRT", OperatorType::kTensorFlowRsqrt));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 37df302d46..88af3d6ab6 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
@@ -86,4 +86,4 @@ class BaseOperator {
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 77c70847d1..9036a16d1c 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -258,16 +258,6 @@ TEST_F(OperatorTest, BuiltinMaxPool) {
EXPECT_EQ(op.kheight, output_toco_op->kheight);
}
-TEST_F(OperatorTest, BuiltinPad) {
- PadOperator op;
- op.left_padding = {1, 2, 3};
- op.right_padding = {1, 2, 3};
- auto output_toco_op =
- SerializeAndDeserialize(GetOperator("PAD", OperatorType::kPad), op);
- EXPECT_EQ(op.left_padding, output_toco_op->left_padding);
- EXPECT_EQ(op.right_padding, output_toco_op->right_padding);
-}
-
TEST_F(OperatorTest, BuiltinReshape) {
TensorFlowReshapeOperator op;
op.shape = {1, 2, 4, 5, 8};
@@ -398,6 +388,28 @@ TEST_F(OperatorTest, Squeeze) {
EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims);
}
+TEST_F(OperatorTest, StridedSlice) {
+ StridedSliceOperator op;
+
+ op.begin_mask = 1;
+ op.end_mask = 2;
+ op.ellipsis_mask = 1;
+ op.new_axis_mask = 1;
+ op.shrink_axis_mask = 2;
+
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
+ EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
+ EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
+ EXPECT_EQ(op.strides, output_toco_op->strides);
+ EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
+ EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
+ EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
+ EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
+ EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
+ EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
+}
+
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h
index 992b98baca..72678c82a2 100644
--- a/tensorflow/contrib/lite/toco/tflite/simple_operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
@@ -47,4 +47,4 @@ class SimpleOperator : public BaseOperator {
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_
diff --git a/tensorflow/contrib/lite/toco/tflite/types.h b/tensorflow/contrib/lite/toco/tflite/types.h
index f7c5140510..3923756fc9 100644
--- a/tensorflow/contrib/lite/toco/tflite/types.h
+++ b/tensorflow/contrib/lite/toco/tflite/types.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/model.h"
@@ -55,4 +55,4 @@ struct ActivationFunction {
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h
index ba35ca8d5d..46eb3f5728 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
#include <string>
#include <vector>
@@ -33,4 +33,4 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_
diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h
index ae0541f62b..d6c3ba6543 100644
--- a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h
+++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
#include <string>
@@ -31,4 +31,4 @@ struct GraphVizDumpOptions {
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_
diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h
index b5cb7a11e7..4be3b5a0bf 100644
--- a/tensorflow/contrib/lite/toco/toco_port.h
+++ b/tensorflow/contrib/lite/toco/toco_port.h
@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
// Portability layer for toco tool. Mainly, abstract filesystem access so we
// can build and use on google internal environments and on OSX.
#include <string>
+#include "google/protobuf/text_format.h"
#include "tensorflow/contrib/lite/toco/format_port.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/platform.h"
@@ -75,6 +76,26 @@ void CopyToBuffer(const ::Cord& src, char* dest);
#endif // PLATFORM_GOOGLE
void CopyToBuffer(const string& src, char* dest);
} // namespace port
+
+inline bool ParseFromStringOverload(const std::string& in,
+ TFLITE_PROTO_NS::Message* proto) {
+ return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
+}
+
+template <typename Proto>
+bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
+ Proto* proto) {
+ if (proto->ParseFromString(input_file_contents)) {
+ return true;
+ }
+
+ if (ParseFromStringOverload(input_file_contents, proto)) {
+ return true;
+ }
+
+ return false;
+}
+
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 94b4d14696..727df1cc76 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -52,7 +52,9 @@ void MakeGeneralGraphTransformationsSet(
GraphTransformationsSet* transformations) {
CHECK(transformations->empty());
transformations->Add(new ConvertExpandDimsToReshape);
+ transformations->Add(new ConvertTrivialAddNToAdd);
transformations->Add(new ConvertTrivialTransposeToReshape);
+ transformations->Add(new ConvertReorderAxes);
transformations->Add(new ResolveReshapeAttributes);
transformations->Add(new PropagateArrayDataTypes);
transformations->Add(new PropagateFixedSizes);
@@ -96,7 +98,6 @@ void MakeGeneralGraphTransformationsSet(
bool SupportsQuantization(FileFormat format) {
return (format == GRAPHVIZ_DOT || format == TFLITE);
- ;
}
bool SupportsFusedActivationFunction(FileFormat format) {
@@ -133,7 +134,7 @@ void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
for (int i = 0; i < model->flags.input_arrays_size(); i++) {
string const& array_name = model->flags.input_arrays(i).name();
- auto* array = model->arrays[array_name].get();
+ auto* array = &model->GetArray(array_name);
// Note that the notion of changing data types only applies to real-numbers
// arrays (see the documentation for inference_input_type).
// TODO(benoitjacob) this is assuming that uint8 arrays are quantized,
@@ -192,6 +193,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
SetFinalDataTypeOnInputs(toco_flags, model);
+ UseArraysExtraInfo(model);
// Remove unused ops before performing any other optimizations. This is to
// stop optimizations from crossing the input/output boundaries. For example
@@ -231,6 +233,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);
+
if (quantize_output) {
RunGraphTransformations(model, "pre-quantization graph transformations",
{new HardcodeMinMax, new DropFakeQuant});
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.h b/tensorflow/contrib/lite/toco/toco_tooling.h
index 9c5a93a211..e731c149ee 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.h
+++ b/tensorflow/contrib/lite/toco/toco_tooling.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
#include <memory>
#include <string>
@@ -47,4 +47,4 @@ inline void Export(const TocoFlags& toco_flags, const Model& model,
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_
diff --git a/tensorflow/contrib/lite/toco/toco_types.h b/tensorflow/contrib/lite/toco/toco_types.h
index ad42497ada..d72a3bd1f3 100644
--- a/tensorflow/contrib/lite/toco/toco_types.h
+++ b/tensorflow/contrib/lite/toco/toco_types.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
#include <string>
#include "tensorflow/core/platform/platform.h"
@@ -42,4 +42,4 @@ using tensorflow::uint8;
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index e09a469d55..187c426a5b 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -93,7 +93,7 @@ int CountOpsWithInput(const Model& model, const string& array_name) {
bool DeleteArrayIfUnused(const string& array_name, Model* model) {
if (CountOpsWithInput(*model, array_name) == 0) {
- model->arrays.erase(array_name);
+ model->EraseArray(array_name);
return true;
}
return false;
@@ -197,6 +197,7 @@ const char* OperatorTypeName(OperatorType type) {
case OperatorType::k##c: \
return #c;
HANDLE_OPERATORTYPENAME_CASE(Add)
+ HANDLE_OPERATORTYPENAME_CASE(AddN)
HANDLE_OPERATORTYPENAME_CASE(AveragePool)
HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
HANDLE_OPERATORTYPENAME_CASE(Conv)
@@ -566,11 +567,11 @@ int RequiredBufferSizeForShape(const Shape& shape) {
}
bool IsConstantParameterArray(const Model& model, const string& name) {
- if (!model.arrays.count(name)) {
+ if (!model.HasArray(name)) {
return false;
}
- return !!model.arrays.at(name)->buffer;
+ return !!model.GetArray(name).buffer;
}
namespace {
@@ -633,17 +634,17 @@ void CheckNonExistentIOArrays(const Model& model) {
return;
}
for (const auto& input_array : model.flags.input_arrays()) {
- CHECK(model.arrays.count(input_array.name()))
+ CHECK(model.HasArray(input_array.name()))
<< "Input array not found: " << input_array.name();
}
for (const string& output_array : model.flags.output_arrays()) {
- CHECK(model.arrays.count(output_array))
+ CHECK(model.HasArray(output_array))
<< "Output array not found: " << output_array;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
- CHECK(model.arrays.count(rnn_state.state_array()));
- CHECK(model.arrays.count(rnn_state.back_edge_source_array()));
+ CHECK(model.HasArray(rnn_state.state_array()));
+ CHECK(model.HasArray(rnn_state.back_edge_source_array()));
}
}
}
@@ -652,10 +653,13 @@ void CheckNonExistentIOArrays(const Model& model) {
void CheckNoMissingArray(const Model& model) {
for (const auto& op : model.operators) {
for (const auto& input : op->inputs) {
- CHECK(model.arrays.count(input));
+ CHECK(model.HasArray(input) || model.optional_arrays.count(input))
+ << "Input: " << input << " missing for op: "
+ << op->outputs[0] << ".";
}
for (const auto& output : op->outputs) {
- CHECK(model.arrays.count(output));
+ CHECK(model.HasArray(output)) << "Output: " << output
+ << " missing.";
}
}
CheckNonExistentIOArrays(model);
@@ -664,12 +668,12 @@ void CheckNoMissingArray(const Model& model) {
void FixNoMissingArray(Model* model) {
for (const auto& op : model->operators) {
for (const auto& input : op->inputs) {
- if (!model->arrays.count(input)) {
+ if (!model->HasArray(input)) {
model->GetOrCreateArray(input);
}
}
for (const auto& output : op->outputs) {
- if (!model->arrays.count(output)) {
+ if (!model->HasArray(output)) {
model->GetOrCreateArray(output);
}
}
@@ -687,7 +691,7 @@ void FixNoMissingArray(Model* model) {
void CheckNoOrphanedArray(const Model& model) {
std::unordered_set<string> arrays_without_known_use;
- for (const auto& array : model.arrays) {
+ for (const auto& array : model.GetArrayMap()) {
if (IsDiscardableArray(model, array.first)) {
arrays_without_known_use.insert(array.first);
}
@@ -714,7 +718,7 @@ void CheckNoOrphanedArray(const Model& model) {
void FixNoOrphanedArray(Model* model) {
std::unordered_set<string> arrays_without_known_use;
- for (const auto& array : model->arrays) {
+ for (const auto& array : model->GetArrayMap()) {
arrays_without_known_use.insert(array.first);
}
for (const auto& op : model->operators) {
@@ -731,13 +735,13 @@ void FixNoOrphanedArray(Model* model) {
}
for (const auto& array : arrays_without_known_use) {
if (IsDiscardableArray(*model, array)) {
- model->arrays.erase(array);
+ model->EraseArray(array);
}
}
}
void CheckArrayFieldsConsistent(const Model& model) {
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = array_entry.second;
if (array->has_shape()) {
for (int d : array->shape().dims()) {
@@ -756,11 +760,13 @@ void CheckArrayFieldsConsistent(const Model& model) {
void CheckOperatorOrdering(const Model& model) {
std::unordered_set<string> arrays_behind_us;
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
if (!GetOpWithOutput(model, array_entry.first)) {
arrays_behind_us.insert(array_entry.first);
}
}
+ arrays_behind_us.insert(model.optional_arrays.begin(),
+ model.optional_arrays.end());
for (const auto& op : model.operators) {
for (const auto& input : op->inputs) {
if (!IsConstantParameterArray(model, input)) {
@@ -779,11 +785,13 @@ void CheckOperatorOrdering(const Model& model) {
void FixOperatorOrdering(Model* model) {
std::unordered_set<string> arrays_behind_us;
- for (const auto& array_entry : model->arrays) {
+ for (const auto& array_entry : model->GetArrayMap()) {
if (!GetOpWithOutput(*model, array_entry.first)) {
arrays_behind_us.insert(array_entry.first);
}
}
+ arrays_behind_us.insert(model->optional_arrays.begin(),
+ model->optional_arrays.end());
std::vector<std::unique_ptr<Operator>> old_operators;
std::swap(old_operators, model->operators);
std::set<std::size_t> remaining;
@@ -932,7 +940,8 @@ void CheckModelCounts(const Model& model) {
if (count_type == "None") {
continue;
} else if (count_type == "Arrays") {
- CheckCountInRange(model_check, model.arrays.size(), "count of arrays");
+ CheckCountInRange(model_check, model.GetArrayMap().size(),
+ "count of arrays");
} else if (count_type == "Total") {
CheckCountInRange(model_check, model.operators.size(),
"count of all operator instances");
@@ -952,7 +961,9 @@ void CheckModelCounts(const Model& model) {
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
std::vector<int>* out_dims) {
CHECK(out_dims->empty());
- if (num_dims == 1) {
+ if (num_dims == 0) {
+ return;
+ } else if (num_dims == 1) {
CHECK_EQ(batch, 1);
*out_dims = {depth};
} else if (num_dims == 2) {
@@ -984,13 +995,13 @@ void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
if (array.has_shape()) {
num_dims = array.shape().dimensions_count();
}
- std::vector<int> dims;
- MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
CHECK(array.data_type == ArrayDataType::kFloat ||
array.data_type == ArrayDataType::kNone);
array.data_type = ArrayDataType::kFloat;
- if (!array.has_shape()) {
+ if (!array.has_shape() && num_dims >= 0) {
Shape* shape = array.mutable_shape();
+ std::vector<int> dims;
+ MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
*shape->mutable_dims() = dims;
}
}
@@ -1179,9 +1190,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
}
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
- if (!rnn_state.manually_create()) {
- continue;
- }
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
model);
}
@@ -1195,6 +1203,9 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
model->flags.set_allow_nonexistent_arrays(
model_flags.allow_nonexistent_arrays());
+
+ CHECK(!model->flags.has_arrays_extra_info());
+ *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
}
void CheckIsReadyForQuantization(const Model& model) {
@@ -1281,6 +1292,8 @@ void DropMinMax(Model* model, const string& array_name) {
}
bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
+ // Optional array is not transient
+ if (model.IsOptionalArray(array_name)) return false;
// The model's input and output arrays are externally allocated.
// They are not transient arrays.
if (IsInputArray(model, array_name)) {
@@ -1291,7 +1304,7 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
return false;
}
}
- const auto& array = model.arrays.at(array_name);
+ const auto& array = &model.GetArray(array_name);
// An array with a constant buffer isn't a transient array.
if (!!array->buffer) {
return false;
@@ -1304,13 +1317,13 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
}
string AvailableArrayName(const Model& model, const string& name) {
- if (!model.arrays.count(name)) {
+ if (!model.HasArray(name) && !model.optional_arrays.count(name)) {
return name;
}
const int kNumSuffixesToTry = 1000;
for (int i = 0; i < kNumSuffixesToTry; i++) {
const string& name_with_suffix = toco::port::StringF("%s_%d", name, i);
- if (!model.arrays.count(name_with_suffix)) {
+ if (!model.HasArray(name_with_suffix)) {
return name_with_suffix;
}
}
@@ -1328,12 +1341,12 @@ string ShapeToString(const Shape& shape) {
}
void PrintArrayShape(Model* model, const string& name) {
- if (!model->arrays[name]->has_shape()) {
+ if (!model->GetArray(name).has_shape()) {
LOG(INFO) << name << " has no shape";
return;
}
LOG(INFO) << name
- << " has shape: " << ShapeToString(model->arrays[name]->shape());
+ << " has shape: " << ShapeToString(model->GetArray(name).shape());
}
bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
@@ -1389,6 +1402,16 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
total += RequiredBufferSizeForShape(output_array.shape());
break;
}
+ case OperatorType::kAddN: {
+ const auto& output_array = model.GetArray(op->outputs[0]);
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // AddN cost is roughly the same cost as N-1 Adds.
+ const int num_adds = op->inputs.size() - 1;
+ total += num_adds * RequiredBufferSizeForShape(output_array.shape());
+ break;
+ }
case OperatorType::kLogistic:
case OperatorType::kSoftmax:
case OperatorType::kTanh: {
@@ -1456,8 +1479,6 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
return true;
}
-namespace {
-
void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
std::vector<int>* shuffle) {
CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
@@ -1492,6 +1513,8 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
}
}
+namespace {
+
// Extend shuffle is designed to match ExtendShape, which pads the shape with
// unit dimensions at the beginning.
void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
@@ -1667,7 +1690,7 @@ bool IsDiscardableArray(const Model& model, const string& array_name) {
}
void CheckFinalDataTypesSatisfied(const Model& model) {
- for (const auto& array_entry : model.arrays) {
+ for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = *array_entry.second;
if (array.final_data_type != ArrayDataType::kNone) {
CHECK(array.final_data_type == array.data_type)
@@ -1694,4 +1717,15 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
}
}
+void UseArraysExtraInfo(Model* model) {
+ for (const auto& entry : model->flags.arrays_extra_info().entries()) {
+ QCHECK(model->HasArray(entry.name()))
+ << "ArraysExtraInfo refers to non-existent array name: "
+ << entry.name();
+ auto& minmax = model->GetArray(entry.name()).GetOrCreateMinMax();
+ minmax.min = entry.min();
+ minmax.max = entry.max();
+ }
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index c81e77874e..2ac51c7e5b 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
#include <algorithm>
#include <cmath>
@@ -23,7 +23,6 @@ limitations under the License.
#include <string>
#include <vector>
-#include "google/protobuf/text_format.h"
#include "tensorflow/core/platform/logging.h"
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/src/google/protobuf/text_format.h"
@@ -84,25 +83,6 @@ void DumpGraphvizVideoFrame(const Model& model);
void LogDump(int log_level, const string& message, const Model& model);
void LogSummary(int log_level, const string& message, const Model& model);
-inline bool ParseFromStringOverload(const std::string& in,
- TFLITE_PROTO_NS::Message* proto) {
- return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
-}
-
-template <typename Proto>
-bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
- Proto* proto) {
- if (proto->ParseFromString(input_file_contents)) {
- return true;
- }
-
- if (ParseFromStringOverload(input_file_contents, proto)) {
- return true;
- }
-
- return false;
-}
-
// TODO(b/36075966): Clean up when dims superseded by array shape.
void ExtendShape(Shape* shape, int new_shape_size);
@@ -274,6 +254,11 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result);
int AxesCount(AxesOrder axes_order);
+// Returns the permutation of the dimensions based on the input axes order and
+// output axes order.
+void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
+ std::vector<int>* shuffle);
+
void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
AxesOrder output_axes_order, Shape* output_shape);
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
@@ -293,6 +278,8 @@ void CheckFinalDataTypesSatisfied(const Model& model);
ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type);
+void UseArraysExtraInfo(Model* model);
+
} // namespace toco
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index 389ef2323a..1bffcfb987 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -42,6 +42,8 @@ tf_cc_binary(
}),
deps = [
":mutable_op_resolver",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
)
@@ -91,3 +93,28 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)
+
+cc_library(
+ name = "verifier",
+ srcs = ["verifier.cc"],
+ hdrs = ["verifier.h"],
+ deps = [
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "verifier_test",
+ size = "small",
+ srcs = ["verifier_test.cc"],
+ deps = [
+ ":verifier",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.h b/tensorflow/contrib/lite/tools/gen_op_registration.h
index 318859e23d..5f2ac6ca97 100644
--- a/tensorflow/contrib/lite/tools/gen_op_registration.h
+++ b/tensorflow/contrib/lite/tools/gen_op_registration.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string.h"
@@ -36,4 +36,4 @@ void ReadOpsFromModel(const ::tflite::Model* model,
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
index 906553da57..573a359c45 100644
--- a/tensorflow/contrib/lite/tools/mutable_op_resolver.h
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
#include <map>
#include "tensorflow/contrib/lite/context.h"
@@ -52,4 +52,4 @@ class MutableOpResolver : public OpResolver {
} // namespace tflite
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc
new file mode 100644
index 0000000000..95a0895379
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/verifier.cc
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/verifier.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+namespace {
+
+const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) {
+ ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
+ if (VerifyModelBuffer(verifier)) {
+ return ::tflite::GetModel(buf);
+ } else {
+ return nullptr;
+ }
+}
+
+} // namespace
+
+bool Verify(const void* buf, size_t len) {
+ const Model* model = VerifyFlatbufferAndGetModel(buf, len);
+ if (model == nullptr) {
+ return false;
+ }
+
+ return model->version() == TFLITE_SCHEMA_VERSION;
+}
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h
new file mode 100644
index 0000000000..03e1f22b7e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/verifier.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_
+
+#include <stdio.h>
+
+namespace tflite {
+
+// Verifies the integrity of a Tensorflow Lite flatbuffer model file.
+// Currently, it verifies:
+// * The file is following a legit flatbuffer schema.
+// * The model is in supported version.
+bool Verify(const void* buf, size_t len);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_
diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc
new file mode 100644
index 0000000000..0481a55a78
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/verifier_test.cc
@@ -0,0 +1,136 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/tools/verifier.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/util.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+using flatbuffers::FlatBufferBuilder;
+using flatbuffers::Offset;
+using flatbuffers::Vector;
+
+// Class that abstracts the list of buffers at the end of the TF Lite structure
+class DeferredBufferWriter {
+ public:
+ DeferredBufferWriter() {
+ data_.push_back({}); // sentinel empty buffer.
+ }
+
+ Offset<Vector<Offset<Buffer>>> BuildBuffers(FlatBufferBuilder *builder) {
+ std::vector<Offset<Buffer>> buffer_vector;
+ for (const auto &vec : data_) {
+ auto data_buffer = builder->CreateVector(vec.data(), vec.size());
+ buffer_vector.push_back(tflite::CreateBuffer(*builder, data_buffer));
+ }
+ return builder->CreateVector(buffer_vector);
+ }
+
+ // Registers a buffer index and takes ownership of the data to write to it.
+ int Record(std::vector<uint8_t> data) {
+ int buffer_index = data_.size();
+ data_.emplace_back(std::move(data));
+ return buffer_index;
+ }
+
+ private:
+ std::vector<std::vector<unsigned char>> data_;
+};
+
+TEST(VerifyModel, TestEmptyModel) {
+ FlatBufferBuilder builder;
+ auto model = CreateModel(builder, /*version=*/TFLITE_SCHEMA_VERSION,
+ /*operator_codes=*/0, /*subgraphs=*/0,
+ /*description=*/0, /*buffers=*/0);
+ ::tflite::FinishModelBuffer(builder, model);
+
+ ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+}
+
+TEST(VerifyModel, TestSimpleModel) {
+ FlatBufferBuilder builder;
+ auto inputs = builder.CreateVector<int32_t>({0});
+ auto outputs = builder.CreateVector<int32_t>({1});
+ auto operator_codes = builder.CreateVector(std::vector<Offset<OperatorCode>>{
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "test")});
+ auto operators =
+ builder.CreateVector(std::vector<Offset<Operator>>{CreateOperator(
+ builder, /*opcode_index=*/0,
+ /*inputs=*/builder.CreateVector<int32_t>({0}),
+ /*outputs=*/builder.CreateVector<int32_t>({1}), BuiltinOptions_NONE,
+ /*builtin_options=*/0,
+ /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS)});
+ std::vector<int> shape;
+ auto tensors = builder.CreateVector(std::vector<Offset<Tensor>>{
+ CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/0,
+ "input", /*quantization=*/0),
+ CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/0,
+ "output", /*quantization=*/0)});
+ auto subgraph = std::vector<Offset<SubGraph>>(
+ {CreateSubGraph(builder, tensors, inputs, outputs, operators,
+ builder.CreateString("Main"))});
+
+ auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, operator_codes,
+ builder.CreateVector(subgraph),
+ builder.CreateString("SmartReply"), /*buffers=*/0);
+
+ ::tflite::FinishModelBuffer(builder, model);
+ ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+}
+
+TEST(VerifyModel, TestCorruptedData) {
+ string model = "123";
+ ASSERT_FALSE(Verify(model.data(), model.size()));
+}
+
+TEST(VerifyModel, TestUnsupportedVersion) {
+ FlatBufferBuilder builder;
+ auto model = CreateModel(builder, /*version=*/1, /*operator_codes=*/0,
+ /*subgraphs=*/0, /*description=*/0, /*buffers=*/0);
+ ::tflite::FinishModelBuffer(builder, model);
+ ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+}
+
+TEST(VerifyModel, TestRandomModificationIsNotAllowed) {
+ FlatBufferBuilder builder;
+ auto model = CreateModel(builder, /*version=*/TFLITE_SCHEMA_VERSION,
+ /*operator_codes=*/0,
+ /*subgraphs=*/0, /*description=*/0, /*buffers=*/0);
+ ::tflite::FinishModelBuffer(builder, model);
+
+ string model_content(reinterpret_cast<char *>(builder.GetBufferPointer()),
+ builder.GetSize());
+ for (int i = 0; i < model_content.size(); i++) {
+ model_content[i] = (model_content[i] + 137) % 255;
+ EXPECT_FALSE(Verify(model_content.data(), model_content.size()))
+ << "Fail at position: " << i;
+ }
+}
+
+// TODO(yichengfan): make up malicious files to test with.
+
+} // namespace tflite
+
+int main(int argc, char **argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/version.h b/tensorflow/contrib/lite/version.h
index a751afabe7..efd63f4006 100644
--- a/tensorflow/contrib/lite/version.h
+++ b/tensorflow/contrib/lite/version.h
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_VERSION_H_
+#define TENSORFLOW_CONTRIB_LITE_VERSION_H_
// The version number of the Schema. Ideally all changes will be backward
// compatible. If that ever changes, we must ensure that version is the first
// entry in the new tflite root so that we can see that version is not 1.
#define TFLITE_SCHEMA_VERSION (3)
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_
+#endif // TENSORFLOW_CONTRIB_LITE_VERSION_H_
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index dd5770dc99..c50f8ceec0 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -377,10 +377,10 @@ $(MARCH_OPTION) \
ifeq ($(BUILD_FOR_TEGRA),1)
NVCC := $(JETPACK)/cuda/bin/nvcc
- NVCCFLAGS := -x=cu -D__CUDACC__ -DNVCC -DNVIDIA_TEGRA -ccbin $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-g++ --std c++11 --expt-relaxed-constexpr -m64 -gencode arch=compute_53,\"code=sm_53\" -gencode arch=compute_62,\"code=sm_62\" -DEIGEN_AVOID_STL_ARRAY -DTENSORFLOW_USE_EIGEN_THREADPOOL -DLANG_CXX11 -DEIGEN_HAS_C99_MATH -DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=5.3
+ NVCCFLAGS := -x=cu -D__CUDACC__ -DNVCC -DANDROID_TEGRA -ccbin $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-g++ --std c++11 --expt-relaxed-constexpr -m64 -gencode arch=compute_53,\"code=sm_53\" -gencode arch=compute_62,\"code=sm_62\" -DEIGEN_AVOID_STL_ARRAY -DTENSORFLOW_USE_EIGEN_THREADPOOL -DLANG_CXX11 -DEIGEN_HAS_C99_MATH -DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=5.3
CXXFLAGS4NVCC =\
-DIS_SLIM_BUILD \
--DNVIDIA_TEGRA \
+-DANDROID_TEGRA \
-fno-exceptions \
-DNDEBUG $(OPTFLAGS) \
-march=armv8-a \
@@ -391,7 +391,7 @@ $(MARCH_OPTION) \
CXXFLAGS +=\
-DGOOGLE_CUDA=1 \
-D__ANDROID_TYPES_FULL__ \
--DNVIDIA_TEGRA \
+-DANDROID_TEGRA \
-DEIGEN_AVOID_STL_ARRAY \
-DEIGEN_HAS_C99_MATH \
-DLANG_CXX11 -DTENSORFLOW_USE_EIGEN_THREADPOOL -DTF_EXTRA_CUDA_CAPABILITIES=5.3
diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh
index 980a44a595..281c4653c6 100755
--- a/tensorflow/contrib/makefile/build_all_android.sh
+++ b/tensorflow/contrib/makefile/build_all_android.sh
@@ -18,7 +18,7 @@
set -e
usage() {
- echo "Usage: NDK_ROOT=<path to ndk root> $(basename "$0") [-Es:t:Tx:a:X]"
+ echo "Usage: NDK_ROOT=<path to ndk root> $(basename "$0") [-Es:t:Tx:a]"
echo "-E enable experimental hexnn ops"
echo "-s [sub_makefiles] sub makefiles separated by white space"
echo "-t [build_target] build target for Android makefile [default=all]"
@@ -37,7 +37,7 @@ fi
ARCH=armeabi-v7a
-while getopts "Es:t:Tx:a:" opt_name; do
+while getopts "Es:t:Tx:a" opt_name; do
case "$opt_name" in
E) ENABLE_EXPERIMENTAL_HEXNN_OPS="true";;
s) SUB_MAKEFILES="${OPTARG}";;
diff --git a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
index 861bb885c7..203ff4f890 100755
--- a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
+++ b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
@@ -76,6 +76,8 @@ GEN_LIBS_DIR="${GEN_DIR}/libs"
GEN_DOWNLOAD_DIR="${GEN_DIR}/downloads"
URL_BASE="https://storage.googleapis.com/download.tensorflow.org"
+ARCH="armeabi-v7a"
+
source "${SCRIPT_DIR}/../build_helper.subr"
rm -rf "${GEN_DIR}"
@@ -219,7 +221,7 @@ if [[ "${BUILD_ONLY}" != "true" ]]; then
adb push "${GEN_LIBS_DIR}/libhexagon_nn_skel.so" "/vendor/lib/rfsa/adsp"
adb push -p \
- "${TF_ROOT_DIR}/tensorflow/contrib/makefile/gen/bin/hexagon_graph_execution" \
+ "${TF_ROOT_DIR}/tensorflow/contrib/makefile/gen/bin/android_${ARCH}/hexagon_graph_execution" \
"/data/local/tmp/"
adb wait-for-device
adb shell chmod "${ANDROID_EXEC_FILE_MODE}" \
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index c3de1c4c62..55946c128b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -339,9 +339,9 @@ def streaming_mean_tensor(values,
name=name)
-@deprecated(
- None, 'Please switch to tf.metrics.accuracy. Note that the order of the '
- 'labels and predictions arguments has been switched.')
+@deprecated(None,
+ 'Please switch to tf.metrics.accuracy. Note that the order of the '
+ 'labels and predictions arguments has been switched.')
def streaming_accuracy(predictions,
labels,
weights=None,
@@ -936,8 +936,9 @@ def streaming_curve_points(labels=None,
if curve != 'ROC' and curve != 'PR':
raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
kepsilon = _EPSILON # to account for floating point imprecisions
- thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- for i in range(num_thresholds - 2)]
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _streaming_confusion_matrix_at_thresholds(
@@ -973,9 +974,8 @@ def streaming_curve_points(labels=None,
return points, update_op
-@deprecated(
- None, 'Please switch to tf.metrics.auc. Note that the order of the '
- 'labels and predictions arguments has been switched.')
+@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of the '
+ 'labels and predictions arguments has been switched.')
def streaming_auc(predictions,
labels,
weights=None,
@@ -1105,8 +1105,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
# For conformance, set precision to 1 when the number of positive
# classifications is 0.
y_axis_values = array_ops.where(
- math_ops.greater(splits, 0),
- math_ops.truediv(true_positives, splits),
+ math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits),
array_ops.ones_like(true_positives, dtype=dtypes.float64))
# Calculate trapezoid areas.
@@ -1119,9 +1118,8 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
# exception seems excessive) so we return 0, otherwise we finish computing.
return control_flow_ops.cond(
math_ops.logical_or(
- math_ops.equal(total_positive, 0),
- math_ops.equal(total_positive, size)
- ),
+ math_ops.equal(total_positive, 0), math_ops.equal(
+ total_positive, size)),
true_fn=lambda: array_ops.constant(0, dtypes.float64),
false_fn=continue_computing_dynamic_auc)
@@ -1185,10 +1183,10 @@ def streaming_dynamic_auc(labels,
array_ops.ones_like(labels, dtypes.int64),
message='labels must be 0 or 1, at least one is >1')
]):
- preds_accum, update_preds = streaming_concat(predictions,
- name='concat_preds')
- labels_accum, update_labels = streaming_concat(labels,
- name='concat_labels')
+ preds_accum, update_preds = streaming_concat(
+ predictions, name='concat_preds')
+ labels_accum, update_labels = streaming_concat(
+ labels, name='concat_labels')
update_op = control_flow_ops.group(update_labels, update_preds)
auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
if updates_collections:
@@ -1571,9 +1569,9 @@ def streaming_precision_at_thresholds(predictions,
name=name)
-@deprecated(
- None, 'Please switch to tf.metrics.recall_at_thresholds. Note that the '
- 'order of the labels and predictions arguments has been switched.')
+@deprecated(None,
+ 'Please switch to tf.metrics.recall_at_thresholds. Note that the '
+ 'order of the labels and predictions arguments has been switched.')
def streaming_recall_at_thresholds(predictions,
labels,
thresholds,
@@ -3299,8 +3297,13 @@ def count(values,
return count_, update_op
-def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
- metrics_collections=None, updates_collections=None, name=None):
+def cohen_kappa(labels,
+ predictions_idx,
+ num_classes,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Calculates Cohen's kappa.
[Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic
@@ -3367,14 +3370,15 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
labels = array_ops.squeeze(labels, axis=[-1])
predictions_idx, labels, weights = (
metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
- predictions=predictions_idx, labels=labels, weights=weights))
+ predictions=predictions_idx,
+ labels=labels,
+ weights=weights))
predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape())
- stat_dtype = (dtypes.int64
- if weights is None or weights.dtype.is_integer
- else dtypes.float32)
- po = metrics_impl.metric_variable(
- (num_classes,), stat_dtype, name='po')
+ stat_dtype = (
+ dtypes.int64
+ if weights is None or weights.dtype.is_integer else dtypes.float32)
+ po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po')
pe_row = metrics_impl.metric_variable(
(num_classes,), stat_dtype, name='pe_row')
pe_col = metrics_impl.metric_variable(
@@ -3382,9 +3386,12 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
# Table of the counts of agreement:
counts_in_table = confusion_matrix.confusion_matrix(
- labels, predictions_idx,
- num_classes=num_classes, weights=weights,
- dtype=stat_dtype, name="counts_in_table")
+ labels,
+ predictions_idx,
+ num_classes=num_classes,
+ weights=weights,
+ dtype=stat_dtype,
+ name='counts_in_table')
po_t = array_ops.diag_part(counts_in_table)
pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0)
@@ -3404,12 +3411,14 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
math_ops.to_double(total))
# kappa = (po - pe) / (N - pe)
k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access
- po_sum - pe_sum, total - pe_sum, name=name)
+ po_sum - pe_sum,
+ total - pe_sum,
+ name=name)
return k
kappa = _calculate_k(po, pe_row, pe_col, name='value')
- update_op = _calculate_k(update_po, update_pe_row, update_pe_col,
- name='update_op')
+ update_op = _calculate_k(
+ update_po, update_pe_row, update_pe_col, name='update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, kappa)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 89aa29f711..e067f08bab 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -46,8 +46,7 @@ def _enqueue_vector(sess, queue, values, shape=None):
shape = (1, len(values))
dtype = queue.dtypes[0]
sess.run(
- queue.enqueue(constant_op.constant(
- values, dtype=dtype, shape=shape)))
+ queue.enqueue(constant_op.constant(values, dtype=dtype, shape=shape)))
def _binary_2d_label_to_sparse_value(labels):
@@ -79,8 +78,8 @@ def _binary_2d_label_to_sparse_value(labels):
batch += 1
shape = [len(labels), len(labels[0])]
return sparse_tensor.SparseTensorValue(
- np.array(indices, np.int64),
- np.array(values, np.int64), np.array(shape, np.int64))
+ np.array(indices, np.int64), np.array(values, np.int64),
+ np.array(shape, np.int64))
def _binary_2d_label_to_sparse(labels):
@@ -125,8 +124,8 @@ def _binary_3d_label_to_sparse_value(labels):
assert label == 0
shape = [len(labels), len(labels[0]), len(labels[0][0])]
return sparse_tensor.SparseTensorValue(
- np.array(indices, np.int64),
- np.array(values, np.int64), np.array(shape, np.int64))
+ np.array(indices, np.int64), np.array(values, np.int64),
+ np.array(shape, np.int64))
def _binary_3d_label_to_sparse(labels):
@@ -669,20 +668,18 @@ class StreamingTruePositivesTest(test.TestCase):
for expand_predictions in [True, False]:
for expand_labels in [True, False]:
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_predictions:
predictions = array_ops.expand_dims(predictions, 2)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_labels:
labels = array_ops.expand_dims(labels, 2)
- tp, tp_update_op = metrics.streaming_true_positives(predictions,
- labels)
+ tp, tp_update_op = metrics.streaming_true_positives(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -692,14 +689,12 @@ class StreamingTruePositivesTest(test.TestCase):
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels, weights=37.0)
@@ -717,28 +712,25 @@ class StreamingFalseNegativesTest(test.TestCase):
ops.reset_default_graph()
def testVars(self):
- metrics.streaming_false_negatives((0, 1, 0),
- (0, 1, 1))
+ metrics.streaming_false_negatives((0, 1, 0), (0, 1, 1))
_assert_metric_variables(self, ('false_negatives/count:0',))
def testUnweighted(self):
for expand_predictions in [True, False]:
for expand_labels in [True, False]:
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_predictions:
predictions = array_ops.expand_dims(predictions, 2)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_labels:
labels = array_ops.expand_dims(labels, 2)
- fn, fn_update_op = metrics.streaming_false_negatives(predictions,
- labels)
+ fn, fn_update_op = metrics.streaming_false_negatives(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -748,14 +740,12 @@ class StreamingFalseNegativesTest(test.TestCase):
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels, weights=((3.0,), (5.0,), (7.0,)))
@@ -773,28 +763,25 @@ class StreamingFalsePositivesTest(test.TestCase):
ops.reset_default_graph()
def testVars(self):
- metrics.streaming_false_positives((0, 1, 0),
- (0, 1, 1))
+ metrics.streaming_false_positives((0, 1, 0), (0, 1, 1))
_assert_metric_variables(self, ('false_positives/count:0',))
def testUnweighted(self):
for expand_predictions in [True, False]:
for expand_labels in [True, False]:
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_predictions:
predictions = array_ops.expand_dims(predictions, 2)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_labels:
labels = array_ops.expand_dims(labels, 2)
- fp, fp_update_op = metrics.streaming_false_positives(predictions,
- labels)
+ fp, fp_update_op = metrics.streaming_false_positives(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -804,20 +791,17 @@ class StreamingFalsePositivesTest(test.TestCase):
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
fp, fp_update_op = metrics.streaming_false_positives(
predictions,
labels,
- weights=((1.0, 2.0, 3.0, 5.0),
- (7.0, 11.0, 13.0, 17.0),
- (19.0, 23.0, 29.0, 31.0)))
+ weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
+ 29.0, 31.0)))
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -833,28 +817,25 @@ class StreamingTrueNegativesTest(test.TestCase):
ops.reset_default_graph()
def testVars(self):
- metrics.streaming_true_negatives((0, 1, 0),
- (0, 1, 1))
+ metrics.streaming_true_negatives((0, 1, 0), (0, 1, 1))
_assert_metric_variables(self, ('true_negatives/count:0',))
def testUnweighted(self):
for expand_predictions in [True, False]:
for expand_labels in [True, False]:
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_predictions:
predictions = array_ops.expand_dims(predictions, 2)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_labels:
labels = array_ops.expand_dims(labels, 2)
- tn, tn_update_op = metrics.streaming_true_negatives(predictions,
- labels)
+ tn, tn_update_op = metrics.streaming_true_negatives(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -864,14 +845,12 @@ class StreamingTrueNegativesTest(test.TestCase):
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),))
@@ -894,12 +873,9 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
_assert_metric_variables(self, ('true_positives:0',))
def testUnweighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
@@ -910,12 +886,9 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
self.assertAllEqual((3, 1, 0), tp.eval())
def testWeighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85))
@@ -937,16 +910,14 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
(0.0, 1.0, 0.0), (0, 1, 1), thresholds=(
0.15,
0.5,
- 0.85,))
+ 0.85,
+ ))
_assert_metric_variables(self, ('false_negatives:0',))
def testUnweighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
@@ -957,12 +928,9 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
self.assertAllEqual((0, 2, 3), fn.eval())
def testWeighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
predictions,
labels,
@@ -988,12 +956,9 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
_assert_metric_variables(self, ('false_positives:0',))
def testUnweighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
@@ -1004,18 +969,14 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
self.assertAllEqual((7, 4, 2), fp.eval())
def testWeighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
predictions,
labels,
- weights=((1.0, 2.0, 3.0, 5.0),
- (7.0, 11.0, 13.0, 17.0),
- (19.0, 23.0, 29.0, 31.0)),
+ weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
+ 29.0, 31.0)),
thresholds=(0.15, 0.5, 0.85))
with self.test_session() as sess:
@@ -1037,12 +998,9 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
_assert_metric_variables(self, ('true_negatives:0',))
def testUnweighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
@@ -1053,12 +1011,9 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
self.assertAllEqual((2, 5, 7), tn.eval())
def testWeighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
predictions,
labels,
@@ -1393,8 +1348,7 @@ class StreamingFPRTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1413,8 +1367,7 @@ class StreamingFPRTest(test.TestCase):
predictions = constant_op.constant(np_inputs)
labels = constant_op.constant(np_inputs)
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1424,8 +1377,7 @@ class StreamingFPRTest(test.TestCase):
def testSomeCorrect(self):
predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1467,8 +1419,7 @@ class StreamingFPRTest(test.TestCase):
predictions = constant_op.constant(np_inputs)
labels = constant_op.constant(1 - np_inputs)
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1478,8 +1429,7 @@ class StreamingFPRTest(test.TestCase):
def testZeroFalsePositivesAndTrueNegativesGivesZeroFPR(self):
predictions = array_ops.ones((1, 4))
labels = array_ops.ones((1, 4))
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1521,8 +1471,7 @@ class StreamingFNRTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1541,8 +1490,7 @@ class StreamingFNRTest(test.TestCase):
predictions = constant_op.constant(np_inputs)
labels = constant_op.constant(np_inputs)
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1552,8 +1500,7 @@ class StreamingFNRTest(test.TestCase):
def testSomeCorrect(self):
predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1595,8 +1542,7 @@ class StreamingFNRTest(test.TestCase):
predictions = constant_op.constant(np_inputs)
labels = constant_op.constant(1 - np_inputs)
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1606,8 +1552,7 @@ class StreamingFNRTest(test.TestCase):
def testZeroFalseNegativesAndTruePositivesGivesZeroFNR(self):
predictions = array_ops.zeros((1, 4))
labels = array_ops.zeros((1, 4))
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1944,16 +1889,17 @@ class StreamingAUCTest(test.TestCase):
enqueue_ops[i].append(x_queue.enqueue(x_batches[i, :]))
return x_queue.dequeue()
- for weights in (None, np.ones(num_samples), np.random.exponential(
- scale=1.0, size=num_samples)):
+ for weights in (None, np.ones(num_samples),
+ np.random.exponential(scale=1.0, size=num_samples)):
expected_auc = _np_auc(predictions, labels, weights)
with self.test_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
tf_labels = _enqueue_as_batches(labels, enqueue_ops)
- tf_weights = (_enqueue_as_batches(weights, enqueue_ops) if
- weights is not None else None)
+ tf_weights = (
+ _enqueue_as_batches(weights, enqueue_ops)
+ if weights is not None else None)
for i in range(num_batches):
sess.run(enqueue_ops[i])
@@ -1985,17 +1931,18 @@ class StreamingDynamicAUCTest(test.TestCase):
def testUnknownCurve(self):
with self.assertRaisesRegexp(
ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'):
- metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)),
- predictions=array_ops.ones((10, 1)),
- curve='TEST_CURVE')
+ metrics.streaming_dynamic_auc(
+ labels=array_ops.ones((10, 1)),
+ predictions=array_ops.ones((10, 1)),
+ curve='TEST_CURVE')
def testVars(self):
metrics.streaming_dynamic_auc(
labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1)))
- _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0',
- 'dynamic_auc/concat_labels/size:0',
- 'dynamic_auc/concat_preds/array:0',
- 'dynamic_auc/concat_preds/size:0'])
+ _assert_metric_variables(self, [
+ 'dynamic_auc/concat_labels/array:0', 'dynamic_auc/concat_labels/size:0',
+ 'dynamic_auc/concat_preds/array:0', 'dynamic_auc/concat_preds/size:0'
+ ])
def testMetricsCollection(self):
my_collection_name = '__metrics__'
@@ -2049,8 +1996,8 @@ class StreamingDynamicAUCTest(test.TestCase):
def testNonZeroOnePredictions(self):
with self.test_session() as sess:
- predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5],
- dtype=dtypes_lib.float32)
+ predictions = constant_op.constant(
+ [2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32)
labels = constant_op.constant([1, 0, 1, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
@@ -2122,9 +2069,10 @@ class StreamingDynamicAUCTest(test.TestCase):
num_batches = 100
labels = np.array([])
predictions = np.array([])
- tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- dtype=dtypes_lib.int32)
+ tf_labels = variables.Variable(
+ array_ops.ones(batch_size, dtypes_lib.int32),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.int32)
tf_predictions = variables.Variable(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -2195,8 +2143,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
gotten_result: A PrecisionRecallData object.
"""
gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()}
- self.assertItemsEqual(
- list(expected_dict.keys()), list(gotten_dict.keys()))
+ self.assertItemsEqual(list(expected_dict.keys()), list(gotten_dict.keys()))
for key, expected_values in expected_dict.items():
self.assertAllClose(expected_values, gotten_dict[key])
@@ -2261,60 +2208,65 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
sess.run(update_op)
# Then verify idempotency.
- initial_result = {k: value.eval().tolist() for k, value in
- result._asdict().items()}
+ initial_result = {
+ k: value.eval().tolist()
+ for k, value in result._asdict().items()
+ }
for _ in range(3):
self._testResultsEqual(initial_result, result)
def testAllTruePositives(self):
- self._testCase([[1]], [[True]], {
- 'tp': [1, 1, 1],
- 'fp': [0, 0, 0],
- 'tn': [0, 0, 0],
- 'fn': [0, 0, 0],
- 'precision': [1.0, 1.0, 1.0],
- 'recall': [1.0, 1.0, 1.0],
- 'thresholds': [0.0, 0.5, 1.0],
- })
+ self._testCase(
+ [[1]], [[True]], {
+ 'tp': [1, 1, 1],
+ 'fp': [0, 0, 0],
+ 'tn': [0, 0, 0],
+ 'fn': [0, 0, 0],
+ 'precision': [1.0, 1.0, 1.0],
+ 'recall': [1.0, 1.0, 1.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ })
def testAllTrueNegatives(self):
- self._testCase([[0]], [[False]], {
- 'tp': [0, 0, 0],
- 'fp': [1, 0, 0],
- 'tn': [0, 1, 1],
- 'fn': [0, 0, 0],
- 'precision': [0.0, 0.0, 0.0],
- 'recall': [0.0, 0.0, 0.0],
- 'thresholds': [0.0, 0.5, 1.0],
- })
+ self._testCase(
+ [[0]], [[False]], {
+ 'tp': [0, 0, 0],
+ 'fp': [1, 0, 0],
+ 'tn': [0, 1, 1],
+ 'fn': [0, 0, 0],
+ 'precision': [0.0, 0.0, 0.0],
+ 'recall': [0.0, 0.0, 0.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ })
def testAllFalsePositives(self):
- self._testCase([[1]], [[False]], {
- 'tp': [0, 0, 0],
- 'fp': [1, 1, 1],
- 'tn': [0, 0, 0],
- 'fn': [0, 0, 0],
- 'precision': [0.0, 0.0, 0.0],
- 'recall': [0.0, 0.0, 0.0],
- 'thresholds': [0.0, 0.5, 1.0],
- })
+ self._testCase(
+ [[1]], [[False]], {
+ 'tp': [0, 0, 0],
+ 'fp': [1, 1, 1],
+ 'tn': [0, 0, 0],
+ 'fn': [0, 0, 0],
+ 'precision': [0.0, 0.0, 0.0],
+ 'recall': [0.0, 0.0, 0.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ })
def testAllFalseNegatives(self):
- self._testCase([[0]], [[True]], {
- 'tp': [1, 0, 0],
- 'fp': [0, 0, 0],
- 'tn': [0, 0, 0],
- 'fn': [0, 1, 1],
- 'precision': [1.0, 0.0, 0.0],
- 'recall': [1.0, 0.0, 0.0],
- 'thresholds': [0.0, 0.5, 1.0],
- })
+ self._testCase(
+ [[0]], [[True]], {
+ 'tp': [1, 0, 0],
+ 'fp': [0, 0, 0],
+ 'tn': [0, 0, 0],
+ 'fn': [0, 1, 1],
+ 'precision': [1.0, 0.0, 0.0],
+ 'recall': [1.0, 0.0, 0.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ })
def testManyValues(self):
self._testCase(
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
- [[True, False, False, True, True, True]],
- {
+ [[True, False, False, True, True, True]], {
'tp': [4, 3, 0],
'fp': [2, 0, 0],
'tn': [0, 2, 2],
@@ -2327,8 +2279,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
def testManyValuesWithWeights(self):
self._testCase(
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
- [[True, False, False, True, True, True]],
- {
+ [[True, False, False, True, True, True]], {
'tp': [1.5, 1.5, 0.0],
'fp': [2.5, 0.0, 0.0],
'tn': [0.0, 2.5, 2.5],
@@ -2644,11 +2595,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -2672,11 +2622,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2690,11 +2639,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
thresholds = [0.5]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2709,11 +2657,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2779,11 +2726,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
thresholds = [-1.0, 2.0] # lower/higher than any values
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
prec_low = prec[0]
prec_high = prec[1]
@@ -2803,11 +2749,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2872,12 +2817,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
tf_predictions = predictions_queue.dequeue()
tf_labels = labels_queue.dequeue()
- prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions,
- tf_labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions,
- tf_labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ tf_predictions, tf_labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ tf_predictions, tf_labels, thresholds)
sess.run(variables.local_variables_initializer())
for _ in range(int(num_samples / batch_size)):
@@ -2921,8 +2864,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
labels=array_ops.ones((10, 1)),
thresholds=[0, 0.5, 1.0],
updates_collections=[my_collection_name])
- self.assertListEqual(
- ops.get_collection(my_collection_name), [update_op])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_uniform(
@@ -3271,8 +3213,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
labels=array_ops.ones((10, 1)),
thresholds=[0, 0.5, 1.0],
updates_collections=[my_collection_name])
- self.assertListEqual(
- ops.get_collection(my_collection_name), [update_op])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_uniform(
@@ -3492,8 +3433,7 @@ class StreamingRecallAtKTest(test.TestCase):
def testVars(self):
metrics.streaming_recall_at_k(
predictions=array_ops.ones((self._batch_size, self._num_classes)),
- labels=array_ops.ones(
- (self._batch_size,), dtype=dtypes_lib.int32),
+ labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32),
k=1)
_assert_metric_variables(self,
('recall_at_1/count:0', 'recall_at_1/total:0'))
@@ -3502,8 +3442,7 @@ class StreamingRecallAtKTest(test.TestCase):
my_collection_name = '__metrics__'
mean, _ = metrics.streaming_recall_at_k(
predictions=array_ops.ones((self._batch_size, self._num_classes)),
- labels=array_ops.ones(
- (self._batch_size,), dtype=dtypes_lib.int32),
+ labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32),
k=1,
metrics_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [mean])
@@ -3512,8 +3451,7 @@ class StreamingRecallAtKTest(test.TestCase):
my_collection_name = '__updates__'
_, update_op = metrics.streaming_recall_at_k(
predictions=array_ops.ones((self._batch_size, self._num_classes)),
- labels=array_ops.ones(
- (self._batch_size,), dtype=dtypes_lib.int32),
+ labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32),
k=1,
updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
@@ -3715,9 +3653,17 @@ class StreamingSparsePrecisionTest(test.TestCase):
# top_k_predictions has rank < 2.
top_k_predictions = [9, 4, 6, 2, 0]
sp_labels = sparse_tensor.SparseTensorValue(
- indices=np.array([[0,], [1,], [2,]], np.int64),
+ indices=np.array([[
+ 0,
+ ], [
+ 1,
+ ], [
+ 2,
+ ]], np.int64),
values=np.array([2, 7, 8], np.int64),
- dense_shape=np.array([10,], np.int64))
+ dense_shape=np.array([
+ 10,
+ ], np.int64))
with self.assertRaises(ValueError):
precision, _ = metrics.streaming_sparse_precision_at_top_k(
@@ -3774,8 +3720,9 @@ class StreamingSparsePrecisionTest(test.TestCase):
# average of the 2 examples.
labels = np.array([labels_ex1, labels_ex2], dtype=np.int64)
predictions = (predictions_ex1, predictions_ex2)
- streaming_precision = [(ex1 + ex2) / 2
- for ex1, ex2 in zip(precision_ex1, precision_ex2)]
+ streaming_precision = [
+ (ex1 + ex2) / 2 for ex1, ex2 in zip(precision_ex1, precision_ex2)
+ ]
streaming_average_precision = [
(ex1 + ex2) / 2
for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
@@ -3835,29 +3782,29 @@ class StreamingSparsePrecisionTest(test.TestCase):
(predictions_top_k_ex1[:k],), labels, expected=avg_precision_ex1[i])
def test_average_precision_at_top_k_static_shape_check(self):
- predictions_top_k = array_ops.placeholder(shape=(2, None),
- dtype=dtypes_lib.int64)
+ predictions_top_k = array_ops.placeholder(
+ shape=(2, None), dtype=dtypes_lib.int64)
labels = np.array(((1,), (2,)), dtype=np.int64)
# Fails due to non-static predictions_idx shape.
with self.assertRaises(ValueError):
- metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k,
- labels)
+ metric_ops.streaming_sparse_average_precision_at_top_k(
+ predictions_top_k, labels)
predictions_top_k = (2, 1)
# Fails since rank of predictions_idx is less than one.
with self.assertRaises(ValueError):
- metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k,
- labels)
+ metric_ops.streaming_sparse_average_precision_at_top_k(
+ predictions_top_k, labels)
predictions_top_k = ((2,), (1,))
# Valid static shape.
- metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k,
- labels)
+ metric_ops.streaming_sparse_average_precision_at_top_k(
+ predictions_top_k, labels)
def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -3871,8 +3818,8 @@ class StreamingSparsePrecisionTest(test.TestCase):
def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -3971,8 +3918,8 @@ class StreamingSparsePrecisionTest(test.TestCase):
[5, 7, 2, 9, 6],
]
sp_labels = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
- [1, 3]],
+ indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], [1,
+ 3]],
# values -1 and 10 are outside the [0, n_classes) range and are ignored.
values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
dense_shape=[2, 4])
@@ -4324,8 +4271,8 @@ class StreamingSparseRecallTest(test.TestCase):
def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
# Classes 0,1 have 0 labels, 0 predictions, classes -1 and 4 are out of
@@ -4340,8 +4287,8 @@ class StreamingSparseRecallTest(test.TestCase):
def test_one_label_at_k1_no_predictions(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -4354,8 +4301,8 @@ class StreamingSparseRecallTest(test.TestCase):
def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -4374,8 +4321,8 @@ class StreamingSparseRecallTest(test.TestCase):
def test_one_label_at_k1_weighted(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -4647,8 +4594,8 @@ class StreamingSparseRecallTest(test.TestCase):
[5, 7, 2, 9, 6],
]
sp_labels = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
- [1, 3]],
+ indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], [1,
+ 3]],
# values -1 and 10 are outside the [0, n_classes) range.
values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
dense_shape=[2, 4])
@@ -4661,10 +4608,7 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2,
class_id=2)
self._test_sparse_recall_at_top_k(
- sp_labels,
- top_k_predictions,
- expected=2.0 / 2,
- class_id=2)
+ sp_labels, top_k_predictions, expected=2.0 / 2, class_id=2)
# Class 5: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -4674,10 +4618,7 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=5)
self._test_sparse_recall_at_top_k(
- sp_labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=5)
+ sp_labels, top_k_predictions, expected=1.0 / 1, class_id=5)
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -4687,10 +4628,7 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.0 / 1,
class_id=7)
self._test_sparse_recall_at_top_k(
- sp_labels,
- top_k_predictions,
- expected=0.0 / 1,
- class_id=7)
+ sp_labels, top_k_predictions, expected=0.0 / 1, class_id=7)
# All classes: 8 labels, 3 correct.
self._test_streaming_sparse_recall_at_k(
@@ -4740,10 +4678,8 @@ class StreamingSparseRecallTest(test.TestCase):
[9, 4, 6, 2, 0],
]]
sparse_labels = _binary_3d_label_to_sparse_value(
- [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
- [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
- [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
+ [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
+ [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
dense_labels = np.array(
[[[2, 7, 8], [1, 2, 5]], [
[1, 2, 5],
@@ -4771,10 +4707,8 @@ class StreamingSparseRecallTest(test.TestCase):
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
- [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
- [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
- [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
+ [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
+ [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
# Class 2: 4 labels, all correct.
self._test_streaming_sparse_recall_at_k(
@@ -4813,10 +4747,8 @@ class StreamingSparseRecallTest(test.TestCase):
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
- [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
- [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
- [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
+ [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
+ [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
for class_id in xrange(10):
self._test_streaming_sparse_recall_at_k(
@@ -4867,10 +4799,8 @@ class StreamingSparseRecallTest(test.TestCase):
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
- [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
- [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
- [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
+ [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
+ [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
@@ -4963,10 +4893,8 @@ class StreamingSparseRecallTest(test.TestCase):
weights=[[0, 1], [0, 1]])
def test_sparse_tensor_value(self):
- predictions = [[0.1, 0.3, 0.2, 0.4],
- [0.1, 0.2, 0.3, 0.4]]
- labels = [[0, 0, 1, 0],
- [0, 0, 0, 1]]
+ predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ labels = [[0, 0, 1, 0], [0, 0, 0, 1]]
expected_recall = 0.5
with self.test_session():
_, recall = metrics.streaming_sparse_recall_at_k(
@@ -5009,8 +4937,8 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_normal((10, 3), seed=1)
labels = random_ops.random_normal((10, 3), seed=2)
- error, update_op = metrics.streaming_mean_absolute_error(predictions,
- labels)
+ error, update_op = metrics.streaming_mean_absolute_error(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5031,8 +4959,8 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
[1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
- error, update_op = metrics.streaming_mean_absolute_error(predictions,
- labels, weights)
+ error, update_op = metrics.streaming_mean_absolute_error(
+ predictions, labels, weights)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5075,8 +5003,8 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
predictions = random_ops.random_normal((10, 3), seed=1)
labels = random_ops.random_normal((10, 3), seed=2)
normalizer = random_ops.random_normal((10, 3), seed=3)
- error, update_op = metrics.streaming_mean_relative_error(predictions,
- labels, normalizer)
+ error, update_op = metrics.streaming_mean_relative_error(
+ predictions, labels, normalizer)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5200,8 +5128,8 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
[1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
- error, update_op = metrics.streaming_mean_squared_error(predictions, labels,
- weights)
+ error, update_op = metrics.streaming_mean_squared_error(
+ predictions, labels, weights)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5224,8 +5152,8 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
_enqueue_vector(sess, labels_queue, [2, 4, 6])
labels = labels_queue.dequeue()
- error, update_op = metrics.streaming_mean_squared_error(predictions,
- labels)
+ error, update_op = metrics.streaming_mean_squared_error(
+ predictions, labels)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
@@ -5292,10 +5220,10 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
_enqueue_vector(sess, labels_queue, [2, 4, 6])
labels = labels_queue.dequeue()
- mae, ma_update_op = metrics.streaming_mean_absolute_error(predictions,
- labels)
- mse, ms_update_op = metrics.streaming_mean_squared_error(predictions,
- labels)
+ mae, ma_update_op = metrics.streaming_mean_absolute_error(
+ predictions, labels)
+ mse, ms_update_op = metrics.streaming_mean_squared_error(
+ predictions, labels)
sess.run(variables.local_variables_initializer())
sess.run([ma_update_op, ms_update_op])
@@ -5336,8 +5264,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_normal((10, 3), seed=1)
labels = random_ops.random_normal((10, 3), seed=2)
- error, update_op = metrics.streaming_root_mean_squared_error(predictions,
- labels)
+ error, update_op = metrics.streaming_root_mean_squared_error(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5357,8 +5285,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
0.0, shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
- rmse, update_op = metrics.streaming_root_mean_squared_error(predictions,
- labels)
+ rmse, update_op = metrics.streaming_root_mean_squared_error(
+ predictions, labels)
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
@@ -5372,8 +5300,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
labels = constant_op.constant(
[1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)
- rmse, update_op = metrics.streaming_root_mean_squared_error(predictions,
- labels)
+ rmse, update_op = metrics.streaming_root_mean_squared_error(
+ predictions, labels)
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(math.sqrt(6), update_op.eval(), 5)
@@ -5387,9 +5315,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
[1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
- rmse, update_op = metrics.streaming_root_mean_squared_error(predictions,
- labels,
- weights)
+ rmse, update_op = metrics.streaming_root_mean_squared_error(
+ predictions, labels, weights)
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(math.sqrt(13), sess.run(update_op))
@@ -5404,8 +5331,8 @@ class StreamingCovarianceTest(test.TestCase):
def testVars(self):
metrics.streaming_covariance(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]))
_assert_metric_variables(self, (
'covariance/comoment:0',
@@ -5417,8 +5344,8 @@ class StreamingCovarianceTest(test.TestCase):
def testMetricsCollection(self):
my_collection_name = '__metrics__'
cov, _ = metrics.streaming_covariance(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]),
metrics_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [cov])
@@ -5426,8 +5353,8 @@ class StreamingCovarianceTest(test.TestCase):
def testUpdatesCollection(self):
my_collection_name = '__updates__'
_, update_op = metrics.streaming_covariance(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]),
updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
@@ -5487,9 +5414,8 @@ class StreamingCovarianceTest(test.TestCase):
cov, update_op = metrics.streaming_covariance(
predictions, labels, weights=weights)
- expected_cov = np.cov([2, 4, 6, 8],
- [1, 3, 2, 7],
- fweights=[0, 1, 3, 1])[0, 1]
+ expected_cov = np.cov(
+ [2, 4, 6, 8], [1, 3, 2, 7], fweights=[0, 1, 3, 1])[0, 1]
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(expected_cov, sess.run(update_op))
self.assertAlmostEqual(expected_cov, cov.eval())
@@ -5514,17 +5440,18 @@ class StreamingCovarianceTest(test.TestCase):
predictions_t: predictions[stride * i:stride * (i + 1)],
labels_t: labels[stride * i:stride * (i + 1)]
}
- self.assertEqual(np.isnan(prev_expected_cov),
- np.isnan(sess.run(cov, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(prev_expected_cov),
+ np.isnan(sess.run(cov, feed_dict=feed_dict)))
if not np.isnan(prev_expected_cov):
- self.assertAlmostEqual(
- prev_expected_cov, sess.run(cov, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(prev_expected_cov,
+ sess.run(cov, feed_dict=feed_dict), 5)
expected_cov = np.cov(predictions[:stride * (i + 1)],
labels[:stride * (i + 1)])[0, 1]
- self.assertAlmostEqual(
- expected_cov, sess.run(update_op, feed_dict=feed_dict), 5)
- self.assertAlmostEqual(
- expected_cov, sess.run(cov, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_cov,
+ sess.run(update_op, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_cov, sess.run(cov, feed_dict=feed_dict),
+ 5)
prev_expected_cov = expected_cov
def testMultiUpdateWithErrorAndWeights(self):
@@ -5552,18 +5479,20 @@ class StreamingCovarianceTest(test.TestCase):
labels_t: labels[stride * i:stride * (i + 1)],
weights_t: weights[stride * i:stride * (i + 1)]
}
- self.assertEqual(np.isnan(prev_expected_cov),
- np.isnan(sess.run(cov, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(prev_expected_cov),
+ np.isnan(sess.run(cov, feed_dict=feed_dict)))
if not np.isnan(prev_expected_cov):
- self.assertAlmostEqual(
- prev_expected_cov, sess.run(cov, feed_dict=feed_dict), 5)
- expected_cov = np.cov(predictions[:stride * (i + 1)],
- labels[:stride * (i + 1)],
- fweights=weights[:stride * (i + 1)])[0, 1]
- self.assertAlmostEqual(
- expected_cov, sess.run(update_op, feed_dict=feed_dict), 5)
- self.assertAlmostEqual(
- expected_cov, sess.run(cov, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(prev_expected_cov,
+ sess.run(cov, feed_dict=feed_dict), 5)
+ expected_cov = np.cov(
+ predictions[:stride * (i + 1)],
+ labels[:stride * (i + 1)],
+ fweights=weights[:stride * (i + 1)])[0, 1]
+ self.assertAlmostEqual(expected_cov,
+ sess.run(update_op, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_cov, sess.run(cov, feed_dict=feed_dict),
+ 5)
prev_expected_cov = expected_cov
@@ -5574,8 +5503,8 @@ class StreamingPearsonRTest(test.TestCase):
def testVars(self):
metrics.streaming_pearson_correlation(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]))
_assert_metric_variables(self, (
'pearson_r/covariance/comoment:0',
@@ -5595,8 +5524,8 @@ class StreamingPearsonRTest(test.TestCase):
def testMetricsCollection(self):
my_collection_name = '__metrics__'
pearson_r, _ = metrics.streaming_pearson_correlation(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]),
metrics_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [pearson_r])
@@ -5604,8 +5533,8 @@ class StreamingPearsonRTest(test.TestCase):
def testUpdatesCollection(self):
my_collection_name = '__updates__'
_, update_op = metrics.streaming_pearson_correlation(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]),
updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
@@ -5613,8 +5542,8 @@ class StreamingPearsonRTest(test.TestCase):
def testValueTensorIsIdempotent(self):
labels = random_ops.random_normal((10, 3), seed=2)
predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5
- pearson_r, update_op = metrics.streaming_pearson_correlation(predictions,
- labels)
+ pearson_r, update_op = metrics.streaming_pearson_correlation(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5633,8 +5562,8 @@ class StreamingPearsonRTest(test.TestCase):
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
- pearson_r, update_op = metrics.streaming_pearson_correlation(predictions,
- labels)
+ pearson_r, update_op = metrics.streaming_pearson_correlation(
+ predictions, labels)
expected_r = np.corrcoef(np.arange(10), np.arange(10))[0, 1]
sess.run(variables.local_variables_initializer())
@@ -5648,8 +5577,8 @@ class StreamingPearsonRTest(test.TestCase):
labels = constant_op.constant(
[1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)
- pearson_r, update_op = metrics.streaming_pearson_correlation(predictions,
- labels)
+ pearson_r, update_op = metrics.streaming_pearson_correlation(
+ predictions, labels)
expected_r = np.corrcoef([2, 4, 6], [1, 3, 2])[0, 1]
sess.run(variables.local_variables_initializer())
@@ -5698,17 +5627,18 @@ class StreamingPearsonRTest(test.TestCase):
predictions_t: predictions[stride * i:stride * (i + 1)],
labels_t: labels[stride * i:stride * (i + 1)]
}
- self.assertEqual(np.isnan(prev_expected_r),
- np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(prev_expected_r),
+ np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
if not np.isnan(prev_expected_r):
- self.assertAlmostEqual(
- prev_expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(prev_expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
expected_r = np.corrcoef(predictions[:stride * (i + 1)],
labels[:stride * (i + 1)])[0, 1]
- self.assertAlmostEqual(
- expected_r, sess.run(update_op, feed_dict=feed_dict), 5)
- self.assertAlmostEqual(
- expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(update_op, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndWeights(self):
@@ -5736,19 +5666,21 @@ class StreamingPearsonRTest(test.TestCase):
labels_t: labels[stride * i:stride * (i + 1)],
weights_t: weights[stride * i:stride * (i + 1)]
}
- self.assertEqual(np.isnan(prev_expected_r),
- np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(prev_expected_r),
+ np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
if not np.isnan(prev_expected_r):
- self.assertAlmostEqual(
- prev_expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
- cmat = np.cov(predictions[:stride * (i + 1)],
- labels[:stride * (i + 1)],
- fweights=weights[:stride * (i + 1)])
+ self.assertAlmostEqual(prev_expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
+ cmat = np.cov(
+ predictions[:stride * (i + 1)],
+ labels[:stride * (i + 1)],
+ fweights=weights[:stride * (i + 1)])
expected_r = cmat[0, 1] / np.sqrt(cmat[0, 0] * cmat[1, 1])
- self.assertAlmostEqual(
- expected_r, sess.run(update_op, feed_dict=feed_dict), 5)
- self.assertAlmostEqual(
- expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(update_op, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndSingletonBatches(self):
@@ -5758,7 +5690,7 @@ class StreamingPearsonRTest(test.TestCase):
predictions = np.random.randn(n)
labels = 0.5 * predictions + np.random.randn(n)
stride = 10
- weights = (np.arange(n).reshape(n//stride, stride) % stride == 0)
+ weights = (np.arange(n).reshape(n // stride, stride) % stride == 0)
for row in weights:
np.random.shuffle(row)
# Now, weights is one-hot by row - one item per batch has non-zero weight.
@@ -5778,19 +5710,20 @@ class StreamingPearsonRTest(test.TestCase):
labels_t: labels[stride * i:stride * (i + 1)],
weights_t: weights[stride * i:stride * (i + 1)]
}
- cmat = np.cov(predictions[:stride * (i + 1)],
- labels[:stride * (i + 1)],
- fweights=weights[:stride * (i + 1)])
+ cmat = np.cov(
+ predictions[:stride * (i + 1)],
+ labels[:stride * (i + 1)],
+ fweights=weights[:stride * (i + 1)])
expected_r = cmat[0, 1] / np.sqrt(cmat[0, 0] * cmat[1, 1])
actual_r = sess.run(update_op, feed_dict=feed_dict)
self.assertEqual(np.isnan(expected_r), np.isnan(actual_r))
- self.assertEqual(np.isnan(expected_r),
- np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(expected_r),
+ np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
if not np.isnan(expected_r):
- self.assertAlmostEqual(
- expected_r, actual_r, 5)
- self.assertAlmostEqual(
- expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r, actual_r, 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
class StreamingMeanCosineDistanceTest(test.TestCase):
@@ -6191,20 +6124,14 @@ class StreamingMeanIOUTest(test.TestCase):
self.assertAlmostEqual(desired_output, miou.eval())
def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
- predictions = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[5]), constant_op.constant(
- 1, shape=[5])
- ],
- 0)
- labels = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[3]), constant_op.constant(
- 1, shape=[7])
- ],
- 0)
+ predictions = array_ops.concat([
+ constant_op.constant(0, shape=[5]),
+ constant_op.constant(1, shape=[5])
+ ], 0)
+ labels = array_ops.concat([
+ constant_op.constant(0, shape=[3]),
+ constant_op.constant(1, shape=[7])
+ ], 0)
num_classes = 2
with self.test_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
@@ -6238,29 +6165,20 @@ class StreamingMeanIOUTest(test.TestCase):
self.assertEqual(0., miou.eval())
def testResultsWithSomeMissing(self):
- predictions = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[5]), constant_op.constant(
- 1, shape=[5])
- ],
- 0)
- labels = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[3]), constant_op.constant(
- 1, shape=[7])
- ],
- 0)
+ predictions = array_ops.concat([
+ constant_op.constant(0, shape=[5]),
+ constant_op.constant(1, shape=[5])
+ ], 0)
+ labels = array_ops.concat([
+ constant_op.constant(0, shape=[3]),
+ constant_op.constant(1, shape=[7])
+ ], 0)
num_classes = 2
- weights = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[1]), constant_op.constant(
- 1, shape=[8]), constant_op.constant(
- 0, shape=[1])
- ],
- 0)
+ weights = array_ops.concat([
+ constant_op.constant(0, shape=[1]),
+ constant_op.constant(1, shape=[8]),
+ constant_op.constant(0, shape=[1])
+ ], 0)
with self.test_session() as sess:
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, weights=weights)
@@ -6270,56 +6188,45 @@ class StreamingMeanIOUTest(test.TestCase):
self.assertAlmostEqual(desired_miou, miou.eval())
def testMissingClassInLabels(self):
- labels = constant_op.constant([
- [[0, 0, 1, 1, 0, 0],
- [1, 0, 0, 0, 0, 1]],
- [[1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0]]])
- predictions = constant_op.constant([
- [[0, 0, 2, 1, 1, 0],
- [0, 1, 2, 2, 0, 1]],
- [[0, 0, 2, 1, 1, 1],
- [1, 1, 2, 0, 0, 0]]])
+ labels = constant_op.constant([[[0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 0, 1]],
+ [[1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0]]])
+ predictions = constant_op.constant(
+ [[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1],
+ [1, 1, 2, 0, 0, 0]]])
num_classes = 3
with self.test_session() as sess:
- miou, update_op = metrics.streaming_mean_iou(
- predictions, labels, num_classes)
+ miou, update_op = metrics.streaming_mean_iou(predictions, labels,
+ num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval())
- self.assertAlmostEqual(
- 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)),
- miou.eval())
+ self.assertAlmostEqual(1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 /
+ (0 + 5 + 0)), miou.eval())
def testMissingClassOverallSmall(self):
labels = constant_op.constant([0])
predictions = constant_op.constant([0])
num_classes = 2
with self.test_session() as sess:
- miou, update_op = metrics.streaming_mean_iou(
- predictions, labels, num_classes)
+ miou, update_op = metrics.streaming_mean_iou(predictions, labels,
+ num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[1, 0], [0, 0]], update_op.eval())
self.assertAlmostEqual(1, miou.eval())
def testMissingClassOverallLarge(self):
- labels = constant_op.constant([
- [[0, 0, 1, 1, 0, 0],
- [1, 0, 0, 0, 0, 1]],
- [[1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0]]])
- predictions = constant_op.constant([
- [[0, 0, 1, 1, 0, 0],
- [1, 1, 0, 0, 1, 1]],
- [[0, 0, 0, 1, 1, 1],
- [1, 1, 1, 0, 0, 0]]])
+ labels = constant_op.constant([[[0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 0, 1]],
+ [[1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0]]])
+ predictions = constant_op.constant(
+ [[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1],
+ [1, 1, 1, 0, 0, 0]]])
num_classes = 3
with self.test_session() as sess:
- miou, update_op = metrics.streaming_mean_iou(
- predictions, labels, num_classes)
+ miou, update_op = metrics.streaming_mean_iou(predictions, labels,
+ num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval())
- self.assertAlmostEqual(
- 1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval())
+ self.assertAlmostEqual(1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)),
+ miou.eval())
class StreamingConcatTest(test.TestCase):
@@ -6683,7 +6590,8 @@ class CohenKappaTest(test.TestCase):
_assert_metric_variables(self, (
'cohen_kappa/po:0',
'cohen_kappa/pe_row:0',
- 'cohen_kappa/pe_col:0',))
+ 'cohen_kappa/pe_col:0',
+ ))
def testMetricsCollection(self):
my_collection_name = '__metrics__'
@@ -6705,9 +6613,9 @@ class CohenKappaTest(test.TestCase):
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_uniform(
- (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1)
+ (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
- (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
+ (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
with self.test_session() as sess:
@@ -6723,10 +6631,7 @@ class CohenKappaTest(test.TestCase):
self.assertAlmostEqual(initial_kappa, kappa.eval(), 5)
def testBasic(self):
- confusion_matrix = np.array([
- [9, 3, 1],
- [4, 8, 2],
- [2, 1, 6]])
+ confusion_matrix = np.array([[9, 3, 1], [4, 8, 2], [2, 1, 6]])
# overall total = 36
# po = [9, 8, 6], sum(po) = 23
# pe_row = [15, 12, 9], pe_col = [13, 14, 9], so pe = [5.42, 4.67, 2.25]
@@ -6738,8 +6643,10 @@ class CohenKappaTest(test.TestCase):
labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
dtypes = [dtypes_lib.int16, dtypes_lib.int32, dtypes_lib.int64]
- shapes = [(len(labels,)), # 1-dim
- (len(labels), 1)] # 2-dim
+ shapes = [
+ (len(labels,)), # 1-dim
+ (len(labels), 1)
+ ] # 2-dim
weights = [None, np.ones_like(labels)]
for dtype in dtypes:
@@ -6795,10 +6702,7 @@ class CohenKappaTest(test.TestCase):
self.assertAlmostEqual(expect, kappa.eval(), 5)
def testWeighted(self):
- confusion_matrix = np.array([
- [9, 3, 1],
- [4, 8, 2],
- [2, 1, 6]])
+ confusion_matrix = np.array([[9, 3, 1], [4, 8, 2], [2, 1, 6]])
labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
num_samples = np.sum(confusion_matrix, dtype=np.int32)
weights = (np.arange(0, num_samples) % 5) / 5.0
@@ -6809,31 +6713,26 @@ class CohenKappaTest(test.TestCase):
with self.test_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
- kappa, update_op = metrics.cohen_kappa(labels, predictions, 4,
- weights=weights)
+ kappa, update_op = metrics.cohen_kappa(
+ labels, predictions, 4, weights=weights)
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(expect, sess.run(update_op), 5)
self.assertAlmostEqual(expect, kappa.eval(), 5)
def testWithMultipleUpdates(self):
- confusion_matrix = np.array([
- [90, 30, 10, 20],
- [40, 80, 20, 30],
- [20, 10, 60, 35],
- [15, 25, 30, 25]])
+ confusion_matrix = np.array([[90, 30, 10, 20], [40, 80, 20, 30],
+ [20, 10, 60, 35], [15, 25, 30, 25]])
labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
num_samples = np.sum(confusion_matrix, dtype=np.int32)
weights = (np.arange(0, num_samples) % 5) / 5.0
num_classes = confusion_matrix.shape[0]
batch_size = num_samples // 10
- predictions_t = array_ops.placeholder(dtypes_lib.float32,
- shape=(batch_size,))
- labels_t = array_ops.placeholder(dtypes_lib.int32,
- shape=(batch_size,))
- weights_t = array_ops.placeholder(dtypes_lib.float32,
- shape=(batch_size,))
+ predictions_t = array_ops.placeholder(
+ dtypes_lib.float32, shape=(batch_size,))
+ labels_t = array_ops.placeholder(dtypes_lib.int32, shape=(batch_size,))
+ weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,))
kappa, update_op = metrics.cohen_kappa(
labels_t, predictions_t, num_classes, weights=weights_t)
with self.test_session() as sess:
@@ -6841,10 +6740,13 @@ class CohenKappaTest(test.TestCase):
for idx in range(0, num_samples, batch_size):
batch_start, batch_end = idx, idx + batch_size
- sess.run(update_op,
- feed_dict={labels_t: labels[batch_start:batch_end],
- predictions_t: predictions[batch_start:batch_end],
- weights_t: weights[batch_start:batch_end]})
+ sess.run(
+ update_op,
+ feed_dict={
+ labels_t: labels[batch_start:batch_end],
+ predictions_t: predictions[batch_start:batch_end],
+ weights_t: weights[batch_start:batch_end]
+ })
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
# labels_np, predictions_np, sample_weight=weights_np)
expect = 0.289965397924
@@ -6862,7 +6764,8 @@ class CohenKappaTest(test.TestCase):
with self.assertRaises(ValueError):
metrics.cohen_kappa(invalid_labels, predictions, 3)
- invalid_predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 2))
+ invalid_predictions = array_ops.placeholder(
+ dtypes_lib.float32, shape=(4, 2))
labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
with self.assertRaises(ValueError):
metrics.cohen_kappa(labels, invalid_predictions, 3)
diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py
index 0d1de869f6..73dd56398c 100644
--- a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py
+++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py
@@ -54,10 +54,10 @@ BATCH_SIZE = 128
DATA_DIR = '/tmp/cifar10_data'
# Constants describing the training process.
-MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
-NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
+MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
+NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor.
-INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
+INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
@@ -82,8 +82,7 @@ def _activation_summary(x):
# session. This helps the clarity of presentation on tensorboard.
tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
tf.summary.histogram(tensor_name + '/activations', x)
- tf.summary.scalar(tensor_name + '/sparsity',
- tf.nn.zero_fraction(x))
+ tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
def _variable_on_cpu(name, shape, initializer):
@@ -120,10 +119,9 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
Variable Tensor
"""
dtype = tf.float32
- var = _variable_on_cpu(
- name,
- shape,
- tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
+ var = _variable_on_cpu(name, shape,
+ tf.truncated_normal_initializer(
+ stddev=stddev, dtype=dtype))
if wd is not None:
weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
tf.add_to_collection('losses', weight_decay)
@@ -188,10 +186,8 @@ def inference(images):
# Note that the masks are applied only to the weight tensors
# conv1
with tf.variable_scope('conv1') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[5, 5, 3, 64],
- stddev=5e-2,
- wd=0.0)
+ kernel = _variable_with_weight_decay(
+ 'weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)
conv = tf.nn.conv2d(
images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
@@ -201,18 +197,20 @@ def inference(images):
_activation_summary(conv1)
# pool1
- pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
- padding='SAME', name='pool1')
+ pool1 = tf.nn.max_pool(
+ conv1,
+ ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1],
+ padding='SAME',
+ name='pool1')
# norm1
- norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
- name='norm1')
+ norm1 = tf.nn.lrn(
+ pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')
# conv2
with tf.variable_scope('conv2') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[5, 5, 64, 64],
- stddev=5e-2,
- wd=0.0)
+ kernel = _variable_with_weight_decay(
+ 'weights', shape=[5, 5, 64, 64], stddev=5e-2, wd=0.0)
conv = tf.nn.conv2d(
norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
@@ -221,19 +219,23 @@ def inference(images):
_activation_summary(conv2)
# norm2
- norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
- name='norm2')
+ norm2 = tf.nn.lrn(
+ conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
# pool2
- pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
- strides=[1, 2, 2, 1], padding='SAME', name='pool2')
+ pool2 = tf.nn.max_pool(
+ norm2,
+ ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1],
+ padding='SAME',
+ name='pool2')
# local3
with tf.variable_scope('local3') as scope:
# Move everything into depth so we can perform a single matrix multiply.
reshape = tf.reshape(pool2, [BATCH_SIZE, -1])
dim = reshape.get_shape()[1].value
- weights = _variable_with_weight_decay('weights', shape=[dim, 384],
- stddev=0.04, wd=0.004)
+ weights = _variable_with_weight_decay(
+ 'weights', shape=[dim, 384], stddev=0.04, wd=0.004)
biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
local3 = tf.nn.relu(
tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases,
@@ -242,8 +244,8 @@ def inference(images):
# local4
with tf.variable_scope('local4') as scope:
- weights = _variable_with_weight_decay('weights', shape=[384, 192],
- stddev=0.04, wd=0.004)
+ weights = _variable_with_weight_decay(
+ 'weights', shape=[384, 192], stddev=0.04, wd=0.004)
biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
local4 = tf.nn.relu(
tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases,
@@ -255,8 +257,8 @@ def inference(images):
# tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
# and performs the softmax internally for efficiency.
with tf.variable_scope('softmax_linear') as scope:
- weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
- stddev=1/192.0, wd=0.0)
+ weights = _variable_with_weight_decay(
+ 'weights', [192, NUM_CLASSES], stddev=1 / 192.0, wd=0.0)
biases = _variable_on_cpu('biases', [NUM_CLASSES],
tf.constant_initializer(0.0))
softmax_linear = tf.add(
@@ -337,11 +339,12 @@ def train(total_loss, global_step):
decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
# Decay the learning rate exponentially based on the number of steps.
- lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
- global_step,
- decay_steps,
- LEARNING_RATE_DECAY_FACTOR,
- staircase=True)
+ lr = tf.train.exponential_decay(
+ INITIAL_LEARNING_RATE,
+ global_step,
+ decay_steps,
+ LEARNING_RATE_DECAY_FACTOR,
+ staircase=True)
tf.summary.scalar('learning_rate', lr)
# Generate moving averages of all losses and associated summaries.
@@ -365,8 +368,8 @@ def train(total_loss, global_step):
tf.summary.histogram(var.op.name + '/gradients', grad)
# Track the moving averages of all trainable variables.
- variable_averages = tf.train.ExponentialMovingAverage(
- MOVING_AVERAGE_DECAY, global_step)
+ variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,
+ global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
@@ -383,10 +386,13 @@ def maybe_download_and_extract():
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
+
def _progress(count, block_size, total_size):
- sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
- float(count * block_size) / float(total_size) * 100.0))
+ sys.stdout.write('\r>> Downloading %s %.1f%%' %
+ (filename,
+ float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
+
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD
index d9d55faf50..23f90cf77e 100644
--- a/tensorflow/contrib/mpi/BUILD
+++ b/tensorflow/contrib/mpi/BUILD
@@ -71,6 +71,8 @@ cc_library(
"//tensorflow/core:protos_cc",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:recent_request_ids",
+ "//tensorflow/core/distributed_runtime:request_id",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_env",
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
index 1a2563d20f..8d14a3ef04 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
@@ -33,8 +33,10 @@ limitations under the License.
namespace tensorflow {
MPIRendezvousMgr::MPIRendezvousMgr(const WorkerEnv* env)
- : BaseRendezvousMgr(env), worker_env_2(env), use_optimal_transfer_(false) {
-
+ : BaseRendezvousMgr(env),
+ worker_env_2(env),
+ use_optimal_transfer_(false),
+ recv_tensor_recent_request_ids_(100000) {
const char* mpienv = getenv("MPI_OPTIMAL_PATH");
if (mpienv && mpienv[0] == '1') {
LOG(INFO) << "MPI Optimal copy path enabled (Requires CUDA-Aware MPI when "
@@ -149,6 +151,8 @@ MPIRemoteRendezvous::~MPIRemoteRendezvous() {}
*/
void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
const int mpi_dst) {
+ TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique(
+ req.request_id(), "RecvTensor (MPIRendezvousMgr)", req));
const int64 step_id = request.step_id();
const std::string& key = request.rendezvous_key();
Rendezvous::ParsedKey parsed;
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
index b15748d63c..ca42ee2f6d 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
@@ -30,10 +30,11 @@ limitations under the License.
#include <iostream>
+#include "tensorflow/contrib/mpi/mpi_msg.pb.h"
#include "tensorflow/contrib/mpi/mpi_utils.h"
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
-#include "tensorflow/contrib/mpi/mpi_msg.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#define TAG_REQTENSOR 1010
@@ -104,6 +105,7 @@ class MPIRequestTensorCall {
void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id) {
req_.set_step_id(step_id);
req_.set_rendezvous_key(parsed.FullKey().data(), parsed.FullKey().size());
+ req_.set_request_id(GetUniqueRequestId());
request_buffer_size_ = req_.ByteSize();
// request_buffer_ = new char[request_buffer_size_];
// req_.SerializeToArray(request_buffer_, request_buffer_size_);
@@ -177,6 +179,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr {
std::map<std::string, std::shared_ptr<MPIRequestTensorCall>> recv_tensor_map_
GUARDED_BY(mrq_);
+ RecentRequestIds recv_tensor_recent_request_ids_;
+
void AddRequest(RecvTensorRequest, const int);
void MPIBackgroundThread();
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
index cb1719c3be..bb219e0edc 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.h
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
+#define TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
#ifdef GOOGLE_CUDA
@@ -136,4 +136,4 @@ class NcclManager {
#endif // GOOGLE_CUDA
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_NCCL_COMMUNICATOR_H_
diff --git a/tensorflow/contrib/ndlstm/python/lstm1d.py b/tensorflow/contrib/ndlstm/python/lstm1d.py
index d3c3531f40..b24e332e4a 100644
--- a/tensorflow/contrib/ndlstm/python/lstm1d.py
+++ b/tensorflow/contrib/ndlstm/python/lstm1d.py
@@ -22,7 +22,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
@@ -85,18 +84,11 @@ def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False):
Output sequence (length, batch_size, noutput)
"""
with variable_scope.variable_scope(scope, "SeqLstm", [inputs]):
- # TODO(tmb) make batch size, sequence_length dynamic
- # example: sequence_length = tf.shape(inputs)[0]
- _, batch_size, _ = _shape(inputs)
- lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False)
- state = array_ops.zeros([batch_size, lstm_cell.state_size])
- sequence_length = int(inputs.get_shape()[0])
- sequence_lengths = math_ops.to_int64(
- array_ops.fill([batch_size], sequence_length))
+ lstm_cell = rnn_cell.BasicLSTMCell(noutput)
if reverse:
inputs = array_ops.reverse_v2(inputs, [0])
outputs, _ = rnn.dynamic_rnn(
- lstm_cell, inputs, sequence_lengths, state, time_major=True)
+ lstm_cell, inputs, time_major=True, dtype=inputs.dtype)
if reverse:
outputs = array_ops.reverse_v2(outputs, [0])
return outputs
diff --git a/tensorflow/contrib/nearest_neighbor/kernels/heap.h b/tensorflow/contrib/nearest_neighbor/kernels/heap.h
index 6e33a574e2..32925569a8 100644
--- a/tensorflow/contrib/nearest_neighbor/kernels/heap.h
+++ b/tensorflow/contrib/nearest_neighbor/kernels/heap.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
+#ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
+#define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
#include <cassert>
#include <cstdint>
@@ -205,4 +205,4 @@ class AugmentedHeap : public HeapBase<KeyType, DataType> {
} // namespace nearest_neighbor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
+#endif // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
diff --git a/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h b/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h
index 1670e2f83b..c53205e1a4 100644
--- a/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h
+++ b/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
+#ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
+#define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -232,4 +232,4 @@ class HyperplaneMultiprobe {
} // namespace nearest_neighbor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
+#endif // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HYPERPLANE_LSH_PROBES_H_
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 9c961f2b9c..827279bd47 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -19,6 +19,7 @@ py_library(
"python/training/elastic_average_optimizer.py",
"python/training/external_optimizer.py",
"python/training/lazy_adam_optimizer.py",
+ "python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
"python/training/multitask_optimizer_wrapper.py",
"python/training/nadam_optimizer.py",
@@ -193,6 +194,27 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "model_average_optimizer_test",
+ srcs = ["python/training/model_average_optimizer_test.py"],
+ additional_deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//third_party/py/numpy",
+ ],
+ tags = [
+ "notap", # This test launches local server.
+ ],
+)
+
py_test(
name = "sign_decay_test",
srcs = ["python/training/sign_decay_test.py"],
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 90d2f92462..6c1bb1adc0 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.opt.python.training.nadam_optimizer import *
from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
+from tensorflow.contrib.opt.python.training.model_average_optimizer import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -48,7 +49,9 @@ _allowed_symbols = [
'MultitaskOptimizerWrapper',
'clip_gradients_by_global_norm',
'ElasticAverageOptimizer',
- 'ElasticAverageCustomGetter'
+ 'ElasticAverageCustomGetter',
+ 'ModelAverageOptimizer',
+ 'ModelAverageCustomGetter'
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index 6132cba1f5..716ee9cdf7 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Wrapper optimizer for Elastic Average SGD """
from __future__ import absolute_import
from __future__ import division
@@ -78,23 +77,24 @@ class ElasticAverageCustomGetter(object):
def __call__(self, getter, name, trainable, collections, *args, **kwargs):
if trainable:
with ops.device(self._worker_device):
- local_var = getter(name, trainable=True,
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- *args, **kwargs)
+ local_var = getter(
+ name,
+ trainable=True,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
global_center_variable = variable_scope.variable(
- name='%s/%s' %
- (GLOBAL_VARIABLE_NAME,
- name),
- initial_value=local_var.initialized_value(),
- trainable=False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
+ initial_value=local_var.initialized_value(),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
with ops.device(self._worker_device):
local_center_variable = variable_scope.variable(
- name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
- trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
+ initial_value=local_var.initialized_value(),
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
self._local_map[local_var] = local_center_variable
self._global_map[local_var] = global_center_variable
@@ -117,16 +117,15 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
# Default value as paper described
BETA = 0.9
- def __init__(
- self,
- opt,
- num_worker,
- ea_custom_getter,
- communication_period=10,
- moving_rate=None,
- rho=None,
- use_locking=True,
- name="ElasticAverageOptimizer"):
+ def __init__(self,
+ opt,
+ num_worker,
+ ea_custom_getter,
+ communication_period=10,
+ moving_rate=None,
+ rho=None,
+ use_locking=True,
+ name='ElasticAverageOptimizer'):
"""Construct a new gradient descent optimizer.
Args:
@@ -160,13 +159,15 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
self._rho = rho
self._local_step = variable_scope.get_variable(
- initializer=0,
- trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- name="local_step")
+ initializer=0,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ name='local_step')
self._opt._prepare()
- def compute_gradients(self, loss, var_list=None,
+ def compute_gradients(self,
+ loss,
+ var_list=None,
gate_gradients=optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
@@ -204,16 +205,18 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
if not var_list:
var_list = variables.trainable_variables()
- elastic_difference = [math_ops.subtract(v, lv) for v, lv in zip(
- variables.trainable_variables(),
- [self._local_map[var] for var in var_list])]
+ elastic_difference = [
+ math_ops.subtract(v, lv)
+ for v, lv in zip(variables.trainable_variables(),
+ [self._local_map[var] for var in var_list])
+ ]
distance_loss = self._rho * math_ops.add_n(
- [gen_nn_ops.l2_loss(ed) for ed in elastic_difference])
+ [gen_nn_ops.l2_loss(ed) for ed in elastic_difference])
total_loss = loss + distance_loss
- return self._opt.compute_gradients(total_loss, var_list,
- gate_gradients, aggregation_method,
+ return self._opt.compute_gradients(total_loss, var_list, gate_gradients,
+ aggregation_method,
colocate_gradients_with_ops, grad_loss)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
@@ -241,7 +244,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
apply_updates = self._opt.apply_gradients(grads_and_vars)
with ops.control_dependencies([apply_updates]):
local_update = state_ops.assign_add(
- self._local_step, 1, name='local_step_update').op
+ self._local_step, 1, name='local_step_update').op
# update global variables.
def _Update_global_variables():
@@ -259,12 +262,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
differences.append(math_ops.subtract(v, lv))
for lvar, diff in zip(local_vars, differences):
with ops.device(lvar.device):
- update_ops.append(state_ops.assign_sub(lvar, math_ops.multiply(
- self._moving_rate, diff)))
+ update_ops.append(
+ state_ops.assign_sub(lvar,
+ math_ops.multiply(self._moving_rate,
+ diff)))
for var, diff in zip(global_center_vars, differences):
with ops.device(var.device):
- update_ops.append(state_ops.assign_add(var, math_ops.multiply(
- self._moving_rate, diff)))
+ update_ops.append(
+ state_ops.assign_add(var,
+ math_ops.multiply(self._moving_rate,
+ diff)))
if global_step:
with ops.colocate_with(global_step):
update_ops.append(state_ops.assign_add(global_step, 1))
@@ -272,10 +279,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
return variable_update
with ops.control_dependencies([local_update]):
- condition = math_ops.equal(math_ops.mod(
- self._local_step, self._period), 0)
+ condition = math_ops.equal(
+ math_ops.mod(self._local_step, self._period), 0)
conditional_update = control_flow_ops.cond(
- condition, _Update_global_variables, control_flow_ops.no_op)
+ condition, _Update_global_variables, control_flow_ops.no_op)
return conditional_update
def get_init_op(self, task_index):
@@ -285,10 +292,12 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
def _Add_sync_queues_and_barrier(enqueue_after_list):
"""Adds ops to enqueu on all worker queues"""
sync_queues = [
- data_flow_ops.FIFOQueue(self._num_worker, [dtypes.bool], shapes=[[]],
- shared_name='%s%s' % (
- 'variable_init_sync_queue', i)) for i in
- range(self._num_worker)]
+ data_flow_ops.FIFOQueue(
+ self._num_worker, [dtypes.bool],
+ shapes=[[]],
+ shared_name='%s%s' % ('variable_init_sync_queue', i))
+ for i in range(self._num_worker)
+ ]
queue_ops = []
# For each other worker, add an entry in a queue
token = constant_op.constant(False)
@@ -299,7 +308,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
else:
queue_ops.append(q.enqueue(token))
queue_ops.append(
- sync_queues[task_index].dequeue_many(len(sync_queues) - 1))
+ sync_queues[task_index].dequeue_many(len(sync_queues) - 1))
return control_flow_ops.group(*queue_ops)
init_ops = []
@@ -307,11 +316,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
global_center_vars = [self._global_map[var] for var in local_vars]
local_center_vars = [self._local_map[var] for var in local_vars]
if not (local_vars and global_center_vars and local_center_vars):
- raise ValueError(
- 'The lists of local_variables, global_center_variables, '
- 'local_center_variables should not be empty ')
- for lvar, gc_var, lc_var in zip(
- local_vars, global_center_vars, local_center_vars):
+ raise ValueError('The lists of local_variables, global_center_variables, '
+ 'local_center_variables should not be empty ')
+ for lvar, gc_var, lc_var in zip(local_vars, global_center_vars,
+ local_center_vars):
init_ops.append(state_ops.assign(lvar, gc_var))
init_ops.append(state_ops.assign(lc_var, gc_var))
@@ -325,6 +333,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
+
def __init__(self, ea_optimizer, is_chief, task_index):
"""Creates hook to handle ElasticAverageOptimizer initialization ops.
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 446e91018d..37539b9599 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -38,20 +38,20 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
cluster_dict = {
- "worker": ["localhost:%s" % port for port in worker_ports],
- "ps": ["localhost:%s" % port for port in ps_ports]
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
}
cs = server_lib.ClusterSpec(cluster_dict)
workers = [
- server_lib.Server(
- cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
- for ix in range(num_workers)
+ server_lib.Server(
+ cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_workers)
]
ps_servers = [
- server_lib.Server(
- cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
- for ix in range(num_ps)
+ server_lib.Server(
+ cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_ps)
]
return cluster_dict, workers, ps_servers
@@ -68,15 +68,14 @@ def _get_workers(num_workers, period, workers, moving_rate):
is_chief = (worker_id == 0)
with graph.as_default():
worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
- ea_coustom = ElasticAverageCustomGetter(
- worker_device=worker_device)
- with variable_scope.variable_scope('',
- custom_getter=ea_coustom), ops.device(
- device_setter.replica_device_setter(worker_device=worker_device,
- ps_device="/job:ps/task:0/cpu:0",
- ps_tasks=1)):
- global_step = variables.Variable(0, name='global_step',
- trainable=False)
+ ea_coustom = ElasticAverageCustomGetter(worker_device=worker_device)
+ with variable_scope.variable_scope(
+ "", custom_getter=ea_coustom), ops.device(
+ device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=1)):
+ global_step = variables.Variable(0, name="global_step", trainable=False)
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
@@ -86,21 +85,19 @@ def _get_workers(num_workers, period, workers, moving_rate):
sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
opt = ElasticAverageOptimizer(
- opt=sgd_opt,
- num_worker=num_workers,
- moving_rate=moving_rate,
- communication_period=period,
- ea_custom_getter=ea_coustom
- )
+ opt=sgd_opt,
+ num_worker=num_workers,
+ moving_rate=moving_rate,
+ communication_period=period,
+ ea_custom_getter=ea_coustom)
train_op = [
- opt.apply_gradients(
- ([grads_0, var_0],
- [grads_1, var_1]), global_step)
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ global_step)
]
easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
# Creates MonitoredSession
- sess = training.MonitoredTrainingSession(workers[worker_id].target,
- hooks=[easgd_hook])
+ sess = training.MonitoredTrainingSession(
+ workers[worker_id].target, hooks=[easgd_hook])
sessions.append(sess)
graphs.append(graph)
@@ -110,6 +107,7 @@ def _get_workers(num_workers, period, workers, moving_rate):
class ElasticAverageOptimizerTest(test.TestCase):
+
def _run(self, train_op, sess):
sess.run(train_op)
@@ -117,15 +115,14 @@ class ElasticAverageOptimizerTest(test.TestCase):
num_workers = 1
communication_period = 2
num_ps = 1
- cluster, workers, _ = create_local_cluster(num_workers=num_workers,
- num_ps=num_ps)
+ cluster, workers, _ = create_local_cluster(
+ num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(num_workers,
- communication_period,
- workers, 1.0)
+ sessions, graphs, train_ops = _get_workers(
+ num_workers, communication_period, workers, 1.0)
- var_0 = graphs[0].get_tensor_by_name('v0:0')
- var_1 = graphs[0].get_tensor_by_name('v1:0')
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
global_step = training_util.get_global_step(graphs[0])
var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
@@ -166,18 +163,17 @@ class ElasticAverageOptimizerTest(test.TestCase):
num_workers = 2
communication_period = 1
num_ps = 2
- cluster, workers, _ = create_local_cluster(num_workers=num_workers,
- num_ps=num_ps)
+ cluster, workers, _ = create_local_cluster(
+ num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(num_workers,
- communication_period,
- workers, 0.5)
+ sessions, graphs, train_ops = _get_workers(
+ num_workers, communication_period, workers, 0.5)
- var_0 = graphs[0].get_tensor_by_name('v0:0')
- var_1 = graphs[0].get_tensor_by_name('v1:0')
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
- var_0_1 = graphs[1].get_tensor_by_name('v0:0')
- var_1_1 = graphs[1].get_tensor_by_name('v1:0')
+ var_0_1 = graphs[1].get_tensor_by_name("v0:0")
+ var_1_1 = graphs[1].get_tensor_by_name("v1:0")
var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
@@ -201,25 +197,24 @@ class ElasticAverageOptimizerTest(test.TestCase):
def testPS2TasksWithClusterSpecClass(self):
cluster_spec = server_lib.ClusterSpec({
- "ps": ["ps0:2222", "ps1:2222"],
- "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+ "ps": ["ps0:2222", "ps1:2222"],
+ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
- ea_coustom = ElasticAverageCustomGetter(
- worker_device="/job:worker/task:0")
+ ea_coustom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
from tensorflow.python.training import device_setter
with ops.device(
device_setter.replica_device_setter(cluster=cluster_spec,
worker_device="/job:worker/task:0",
ps_device="/job:ps")), \
- variable_scope.variable_scope('', custom_getter=ea_coustom):
+ variable_scope.variable_scope("", custom_getter=ea_coustom):
v = variable_scope.get_variable(initializer=[1, 2], name="v")
- w = variable_scope.get_variable(initializer=[2, 1], name='w')
- v_g, w_g = ea_coustom._global_map[v],ea_coustom._global_map[w]
+ w = variable_scope.get_variable(initializer=[2, 1], name="w")
+ v_g, w_g = ea_coustom._global_map[v], ea_coustom._global_map[w]
self.assertDeviceEqual("/job:worker/task:0", v.device)
self.assertDeviceEqual("job:ps/task:0", v_g.device)
self.assertDeviceEqual("/job:worker/task:0", w.device)
self.assertDeviceEqual("job:ps/task:1", w_g.device)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index 4c3fec0672..aeca900bc8 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -47,8 +47,9 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
"""
def _apply_sparse(self, grad, var):
- beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
- beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
new file mode 100644
index 0000000000..a7c97a1da2
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
@@ -0,0 +1,308 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrapper optimizer for Model Average."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import session_run_hook
+
+GLOBAL_VARIABLE_NAME = "global_center_variable"
+
+
+class ModelAverageCustomGetter(object):
+ """Custom_getter class is used to do.
+
+ 1. Change trainable variables to local collection and place them at worker
+ device
+ 2. Generate global variables
+ Notice that the class should be used with tf.replica_device_setter,
+ so that the global center variables and global step variable can be placed
+ at ps device. Besides, use 'tf.get_variable' instead of 'tf.Variable' to
+ use this custom getter.
+
+ For example,
+ ma_custom_getter = ModelAverageCustomGetter(worker_device)
+ with tf.device(
+ tf.train.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/cpu:0",
+ cluster=cluster)),
+ tf.variable_scope('',custom_getter=ma_custom_getter):
+ hid_w = tf.get_variable(
+ initializer=tf.truncated_normal(
+ [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
+ stddev=1.0 / IMAGE_PIXELS),
+ name="hid_w")
+ hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
+ name="hid_b")
+ """
+
+ def __init__(self, worker_device):
+ """Create a new `ElasticAverageCustomGetter`.
+
+ Args:
+ worker_device: String. Name of the `worker` job.
+ """
+ self._worker_device = worker_device
+ self._local_2_global = {}
+
+ def __call__(self, getter, name, trainable, collections, *args, **kwargs):
+ if trainable:
+ with ops.device(self._worker_device):
+ local_var = getter(
+ name,
+ trainable=True,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+
+ global_variable = variable_scope.variable(
+ name="%s/%s" % (GLOBAL_VARIABLE_NAME, name),
+ initial_value=local_var.initialized_value(),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+
+ self._local_2_global[local_var] = global_variable
+ return local_var
+ else:
+ return getter(name, trainable, collections, *args, **kwargs)
+
+
+class ModelAverageOptimizer(optimizer.Optimizer):
+ """Wrapper optimizer that implements the Model Average algorithm.
+
+ This is a sync optimizer. During the training, each worker will update
+ the local variables and maintains its own local_step, which starts from 0
+ and is incremented by 1 after each update of local variables. Whenever the
+ interval_steps divides the local step, the local variables from all the
+ workers will be averaged and assigned to global center variables. Then the
+ local variables will be assigned by global center variables.
+ """
+
+ def __init__(self,
+ opt,
+ num_worker,
+ is_chief,
+ ma_custom_getter,
+ interval_steps=100,
+ use_locking=True,
+ name="ModelAverageOptimizer"):
+ """Construct a new model average optimizer.
+
+ Args:
+ opt: The actual optimizer that will be used to update local variables
+ num_worker: The number of workers
+ is_chief: whether chief worker
+ ma_custom_getter: ModelAverageCustomGetter
+ interval_steps: An int point value to controls the frequency of the
+ average of local variables
+ use_locking: If True use locks for update operations
+ name: string. Optional name of the returned operation
+ """
+ super(ModelAverageOptimizer, self).__init__(use_locking, name)
+ self._opt = opt
+ self._num_worker = num_worker
+ self._is_chief = is_chief
+ self._local_2_global = ma_custom_getter._local_2_global # pylint:disable=protected-access
+ self._interval_steps = interval_steps
+ self._accumulator_list = []
+ self._chief_init_op = None
+
+ self._local_step = variable_scope.get_variable(
+ initializer=0,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ name="local_step")
+
+ self._opt._prepare() # pylint:disable=protected-access
+
+ def compute_gradients(self, *args, **kwargs):
+ """Compute gradients of "loss" for the variables in "var_list".
+
+ This simply wraps the compute_gradients() from the real optimizer.
+
+ Args:
+ *args: Arguments for compute_gradients().
+ **kwargs: Keyword arguments for compute_gradients().
+
+ Returns:
+ A list of (gradient, variable) pairs.
+ """
+ return self._opt.compute_gradients(*args, **kwargs)
+
+ def _local_vars_update(self, var_list):
+ """Get the update ops for the local variables in "var_list".
+
+ Args:
+ var_list: Optional list or tuple of 'tf.Variable' to update
+
+ Returns:
+ An update op
+
+ Raises:
+ ValueError: if var_list is empty.
+ """
+ if not var_list:
+ raise ValueError("The list of local_variables should not be empty")
+ update_ops = []
+ global_center_vars = [self._local_2_global[var] for var in var_list]
+ for lvar, gvar in zip(var_list, global_center_vars):
+ with ops.device(lvar.device):
+ update_ops.append(state_ops.assign(lvar, gvar.read_value()))
+ return control_flow_ops.group(*(update_ops))
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to variables.
+
+ This contains most of the synchronization implementation and also wraps the
+ apply_gradients() from the real optimizer. The chief work updates global
+ variables.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ compute_gradients().
+ global_step: Optional Variable to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the Optimizer constructor.
+
+ Returns:
+ A conditional 'Operation' that update both local and global variables or
+ just local variables
+
+ Raises:
+ ValueError: If the grads_and_vars is empty.
+ ValueError: If global step is not provided, the staleness cannot be
+ checked.
+ """
+
+ # update local variables
+ if not grads_and_vars:
+ raise ValueError("Must supply at least one variable")
+ if global_step is None:
+ raise ValueError("Global step is required")
+
+ apply_updates = self._opt.apply_gradients(grads_and_vars)
+ with ops.control_dependencies([apply_updates]):
+ local_update = state_ops.assign_add(
+ self._local_step, 1, name="local_step_update").op
+
+ # update global variables.
+ def _update_global_variables(): # pylint: disable=missing-docstring
+ local_vars = [v for g, v in grads_and_vars if g is not None]
+ global_vars = [self._local_2_global[v] for v in local_vars]
+ # sync queue
+ with ops.colocate_with(global_step):
+ sync_queue = data_flow_ops.FIFOQueue(
+ -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue")
+ train_ops = []
+ aggregated_vars = []
+ with ops.name_scope(None, self._name + "/global"):
+ for var, gvar in zip(local_vars, global_vars):
+ # pylint: disable=protected-access
+ with ops.device(gvar.device):
+ if isinstance(var._ref(), ops.Tensor):
+ var_accum = data_flow_ops.ConditionalAccumulator(
+ var.dtype,
+ shape=var.get_shape(),
+ shared_name=gvar.name + "/var_accum")
+ train_ops.append(
+ var_accum.apply_grad(var._ref(), local_step=global_step))
+ aggregated_vars.append(var_accum.take_grad(self._num_worker))
+ else:
+ raise ValueError("Unknown local variable type!")
+ self._accumulator_list.append((var_accum, gvar.device))
+ # chief worker updates global vars and enqueues tokens to the sync queue
+ if self._is_chief:
+ update_ops = []
+ with ops.control_dependencies(train_ops):
+ for avg_var, gvar in zip(aggregated_vars, global_vars):
+ with ops.device(gvar.device):
+ update_ops.append(state_ops.assign(gvar, avg_var))
+ with ops.device(global_step.device):
+ update_ops.append(state_ops.assign_add(global_step, 1))
+ with ops.control_dependencies(update_ops), ops.device(
+ global_step.device):
+ tokens = array_ops.fill([self._num_worker - 1],
+ constant_op.constant(False))
+ sync_op = sync_queue.enqueue_many(tokens)
+ else:
+ with ops.control_dependencies(train_ops), ops.device(
+ global_step.device):
+ sync_op = sync_queue.dequeue()
+
+ with ops.control_dependencies([sync_op]):
+ local_update_op = self._local_vars_update(local_vars)
+ return local_update_op
+
+ with ops.control_dependencies([local_update]):
+ condition = math_ops.equal(
+ math_ops.mod(self._local_step, self._interval_steps), 0)
+ conditional_update = control_flow_ops.cond(
+ condition, _update_global_variables, control_flow_ops.no_op)
+
+ chief_init_ops = []
+ for accum, dev in self._accumulator_list:
+ with ops.device(dev):
+ chief_init_ops.append(
+ accum.set_global_step(global_step, name="SetGlobalStep"))
+ self._chief_init_op = control_flow_ops.group(*(chief_init_ops))
+
+ return conditional_update
+
+ def get_init_op(self):
+ """Returns the op.
+
+ This method lets all the local variables equal to the global
+ variables before the training begins.
+ """
+ return self._local_vars_update(variables.trainable_variables())
+
+ def make_session_run_hook(self):
+ """Creates a hook to handle ModelAverage ops such as initialization."""
+ return _ModelAverageOptimizerHook(self, self._is_chief)
+
+
+class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook): # pylint: disable=missing-docstring
+
+ def __init__(self, ma_optimizer, is_chief):
+ """Creates hook to handle ModelAverageOptimizer initialization ops.
+
+ Args:
+ ma_optimizer: `ModelAverageOptimizer` which this hook will initialize.
+ is_chief: `Bool`, whether is this a chief replica or not.
+ """
+ self._ma_optimizer = ma_optimizer
+ self._is_chief = is_chief
+
+ def begin(self):
+ self._local_init_op = variables.local_variables_initializer()
+ self._global_init_op = None
+ if self._is_chief:
+ self._global_init_op = variables.global_variables_initializer()
+ self._chief_init_op = self._ma_optimizer._chief_init_op # pylint: disable=protected-access
+ self._variable_init_op = self._ma_optimizer.get_init_op()
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
new file mode 100644
index 0000000000..6cca0a8a00
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -0,0 +1,198 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ModelAverageOptimizer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import portpicker
+
+from tensorflow.contrib.opt.python.training import model_average_optimizer
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import device_setter
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import server_lib
+from tensorflow.python.training import training
+from tensorflow.python.training import training_util
+
+
+def create_local_cluster(num_workers, num_ps, protocol="grpc"):
+ """Create local GRPC servers and return them."""
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ cluster_dict = {
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
+ }
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ workers = [
+ server_lib.Server(
+ cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_workers)
+ ]
+ ps_servers = [
+ server_lib.Server(
+ cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_ps)
+ ]
+
+ return cluster_dict, workers, ps_servers
+
+
+# Creates the workers and return their sessions, graphs, train_ops.
+# Cheif worker will update at last
+def _get_workers(num_workers, steps, workers):
+ sessions = []
+ graphs = []
+ train_ops = []
+ for worker_id in range(num_workers):
+ graph = ops.Graph()
+ is_chief = (worker_id == 0)
+ with graph.as_default():
+ worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
+ ma_coustom = model_average_optimizer.ModelAverageCustomGetter(
+ worker_device=worker_device)
+ with variable_scope.variable_scope(
+ "", custom_getter=ma_coustom), ops.device(
+ device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=1)):
+
+ global_step = variables.Variable(0, name="global_step", trainable=False)
+ var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
+ var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
+
+ with ops.device("/job:worker/task:" + str(worker_id)):
+ if worker_id == 0:
+ grads_0 = constant_op.constant(-1.0)
+ grads_1 = constant_op.constant(-1.0)
+ else:
+ grads_0 = constant_op.constant(-2.0)
+ grads_1 = constant_op.constant(-2.0)
+ sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ opt = model_average_optimizer.ModelAverageOptimizer(
+ opt=sgd_opt,
+ num_worker=num_workers,
+ ma_custom_getter=ma_coustom,
+ is_chief=is_chief,
+ interval_steps=steps)
+ train_op = [
+ opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
+ global_step)
+ ]
+ easgd_hook = opt.make_session_run_hook()
+ # Creates MonitoredSession
+ sess = training.MonitoredTrainingSession(
+ workers[worker_id].target, hooks=[easgd_hook])
+
+ sessions.append(sess)
+ graphs.append(graph)
+ train_ops.append(train_op)
+ return sessions, graphs, train_ops
+
+
+class ModelAverageOptimizerTest(test.TestCase):
+ def _run(self, train_op, sess):
+ sess.run(train_op)
+
+ def test1Workers2Period(self):
+ num_workers = 2
+ steps = 2
+ num_ps = 1
+ _, workers, _ = create_local_cluster(
+ num_workers=num_workers, num_ps=num_ps)
+
+ sessions, graphs, train_ops = _get_workers(num_workers, steps, workers)
+
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
+ global_step = training_util.get_global_step(graphs[0])
+ global_var_0 = graphs[0].get_tensor_by_name(
+ model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
+ global_var_1 = graphs[0].get_tensor_by_name(
+ model_average_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")
+
+ # Verify the initialized value.
+ self.assertAllEqual(0.0, sessions[0].run(var_0))
+ self.assertAllEqual(1.0, sessions[0].run(var_1))
+ self.assertAllEqual(0.0, sessions[0].run(global_var_0))
+ self.assertAllEqual(1.0, sessions[0].run(global_var_1))
+ self.assertAllEqual(0, sessions[0].run(global_step))
+
+ sessions[0].run(train_ops[0])
+ sessions[1].run(train_ops[1])
+
+ self.assertAllEqual(1.0, sessions[0].run(var_0))
+ self.assertAllEqual(2.0, sessions[0].run(var_1))
+ self.assertAllEqual(0.0, sessions[0].run(global_var_0))
+ self.assertAllEqual(1.0, sessions[0].run(global_var_1))
+ self.assertAllEqual(0, sessions[0].run(global_step))
+
+ # iteration 2, global varibale update
+ thread_0 = self.checkedThread(
+ target=self._run, args=(train_ops[0], sessions[0]))
+ thread_1 = self.checkedThread(
+ target=self._run, args=(train_ops[1], sessions[1]))
+ thread_0.start()
+ thread_1.start()
+ thread_0.join()
+ thread_1.join()
+
+ self.assertAllEqual(3.0, sessions[0].run(var_0))
+ self.assertAllEqual(4.0, sessions[0].run(var_1))
+ self.assertAllEqual(3.0, sessions[0].run(global_var_0))
+ self.assertAllEqual(4.0, sessions[0].run(global_var_1))
+ self.assertAllEqual(1, sessions[0].run(global_step))
+
+ # iteration 3
+ sessions[0].run(train_ops[0])
+
+ self.assertAllEqual(4.0, sessions[0].run(var_0))
+ self.assertAllEqual(5.0, sessions[0].run(var_1))
+ self.assertAllEqual(3.0, sessions[0].run(global_var_0))
+ self.assertAllEqual(4.0, sessions[0].run(global_var_1))
+ self.assertAllEqual(1, sessions[0].run(global_step))
+
+ def testPS2TasksWithClusterSpecClass(self):
+ cluster_spec = server_lib.ClusterSpec({
+ "ps": ["ps0:2222", "ps1:2222"],
+ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+ })
+ worker_device = "/job:worker/task:0"
+ ma_coustom = model_average_optimizer.ModelAverageCustomGetter(
+ worker_device=worker_device)
+ from tensorflow.python.training import device_setter
+ with ops.device(
+ device_setter.replica_device_setter(cluster=cluster_spec,
+ worker_device=worker_device,
+ ps_device="/job:ps")), \
+ variable_scope.variable_scope("", custom_getter=ma_coustom):
+ v = variable_scope.get_variable(initializer=[1, 2], name="v")
+ w = variable_scope.get_variable(initializer=[2, 1], name="w")
+ v_g, w_g = ma_coustom._local_2_global[v], ma_coustom._local_2_global[w]
+ self.assertDeviceEqual("/job:worker/task:0", v.device)
+ self.assertDeviceEqual("job:ps/task:0", v_g.device)
+ self.assertDeviceEqual("/job:worker/task:0", w.device)
+ self.assertDeviceEqual("job:ps/task:1", w_g.device)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py
index a4421ecfe6..44a8890cb1 100644
--- a/tensorflow/contrib/opt/python/training/nadam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py
@@ -34,12 +34,13 @@ class NadamOptimizer(adam.AdamOptimizer):
def _apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
+ beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.apply_adam(
var,
m,
v,
- math_ops.cast(self._beta1_power, var.dtype.base_dtype),
- math_ops.cast(self._beta2_power, var.dtype.base_dtype),
+ math_ops.cast(beta1_power, var.dtype.base_dtype),
+ math_ops.cast(beta2_power, var.dtype.base_dtype),
math_ops.cast(self._lr_t, var.dtype.base_dtype),
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
@@ -51,12 +52,13 @@ class NadamOptimizer(adam.AdamOptimizer):
def _resource_apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
+ beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.resource_apply_adam(
var.handle,
m.handle,
v.handle,
- math_ops.cast(self._beta1_power, grad.dtype.base_dtype),
- math_ops.cast(self._beta2_power, grad.dtype.base_dtype),
+ math_ops.cast(beta1_power, grad.dtype.base_dtype),
+ math_ops.cast(beta2_power, grad.dtype.base_dtype),
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
@@ -66,8 +68,9 @@ class NadamOptimizer(adam.AdamOptimizer):
use_nesterov=True)
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
- beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
- beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD
index 71582f9c9a..bd9078ae76 100644
--- a/tensorflow/contrib/periodic_resample/BUILD
+++ b/tensorflow/contrib/periodic_resample/BUILD
@@ -6,6 +6,7 @@ exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
+ "py_test",
"tf_gen_op_libs",
"tf_custom_op_library",
"tf_custom_op_py_library",
@@ -64,11 +65,28 @@ py_library(
"python/__init__.py",
],
srcs_version = "PY2AND3",
+ tags = [
+ "notap",
+ ],
deps = [
":periodic_resample_op_py",
],
)
+py_test(
+ name = "periodic_resample_op_test",
+ srcs = ["python/kernel_tests/periodic_resample_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "notap",
+ ],
+ deps = [
+ ":init_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
# py_library(
# name = "periodic_resample_op_py",
# srcs = ["python/ops/periodic_resample_op.py"],
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
index bef21f7a5c..ba410f025d 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
@@ -100,6 +100,8 @@ template <class InputDataT,
desired_shape.size(), "."));
bool found = false;
+ const auto& input_tensor_shape = input_tensor.shape();
+
for (int i = 0; i < rank; ++i) {
// if (desired_shape(i) < 1) {
if (desired_shape[i] < 1) {
@@ -111,6 +113,15 @@ template <class InputDataT,
adjustable_dimension = i;
found = true;
} else {
+ OP_REQUIRES(
+ context, desired_shape[i] >= input_tensor_shape.dim_size(i),
+ tensorflow::errors::InvalidArgument(
+ "periodic_resample expects the size of non-adjustable "
+ "dimensions be at least as large as size of input tensor."
+ " Dimension ", i, " input tensor has size ",
+ input_tensor_shape.dim_size(i), ", desired shape has size ",
+ desired_shape[i], "."));
+
// target_dimensions[i] = desired_shape(i);
target_dimensions[i] = desired_shape[i];
new_sliced_size *= target_dimensions[i];
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
index c90fc06c7f..82bd796956 100644
--- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc
+++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
@@ -34,26 +34,40 @@ This function implements a slightly more generic version of the subpixel
convolutions found in this [paper](https://arxiv.org/abs/1609.05158).
The formula for computing the elements in the `output` tensor is as follows:
+
`T` = `values` tensor of rank `R`
+
`S` = desired `shape` of output tensor (vector of length `R`)
+
`P` = `output` tensor of rank `R`
- \((T_1,\ldots,T_R)\) = shape(`T`)
- \([S_1,\ldots,S_q,\ldots,S_R]\) = elements of vector `S`
- A single element in `S` is left unspecified (denoted \(S_q=-1\)).
- Let \(f_i\) denote the (possibly non-integer) factor that relates the original
- dimension to the desired dimensions, \(S_i=f_i T_i\), for \(i\neq q\) where
- \(f_i>0\).
+ \\((T_1,\\ldots,T_R)\\) = shape(`T`)
+
+ \\([S_1,\\ldots,S_q,\\ldots,S_R]\\) = elements of vector `S`
+
+ A single element in `S` is left unspecified (denoted \\(S_q=-1\\)).
+
+ Let \\(f_i\\) denote the (possibly non-integer) factor that relates the original
+ dimension to the desired dimensions, \\(S_i=f_i T_i\\), for \\(i\\neq q\\) where
+ \\(f_i>0\\).
+
Define the following:
- \(g_i=\lceil f_i\rceil\)
- \(t=\prod_i T_i\)
- \(s=\prod_{i\neq q} S_i\)
- \(S_q\) can then be defined as by \(S_q=\lfloor t/s\rfloor\).
+
+ \\(g_i=\\lceil f_i\\rceil\\)
+
+ \\(t=\\prod_i T_i\\)
+
+ \\(s=\\prod_{i\\neq q} S_i\\)
+
+ \\(S_q\\) can then be defined by \\(S_q=\\lfloor t/s\\rfloor\\).
The elements of the resulting tensor are defined as
- \(P_{s_1,\ldots,s_R}=T_{h_1,\ldots,h_q,\ldots,h_R}\).
- The \(h_i\) (\(i\neq q\)) are defined by \(h_i=\lfloor s_i/g_i\rfloor\).
- \(h_q=S_q\sum_{j\neq q}^{q-1}G_j \mathrm{mod}(s_j,g_j) + s_q\), where
- \(G_j=\prod_{i}^{j-1}g_i\) (\(G_0=1\)).
+
+ \\(P_{s_1,\\ldots,s_R}=T_{h_1,\\ldots,h_q,\\ldots,h_R}\\).
+
+ The \\(h_i\\) (\\(i\\neq q\\)) are defined by \\(h_i=\\lfloor s_i/g_i\\rfloor\\).
+
+ \\(h_q=S_q\\sum_{j\\neq q}^{q-1}G_j \\mathrm{mod}(s_j,g_j) + s_q\\), where
+ \\(G_j=\\prod_{i}^{j-1}g_i\\) (\\(G_0=1\\)).
One drawback of this method is that whenever the output dimensions are slightly
less than integer multiples of the input dimensions, many of the tensor elements
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 1d727870f6..a25de55e18 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -19,8 +19,9 @@ from __future__ import division
from __future__ import print_function
import numpy
-import tensorflow
+
from tensorflow.contrib.periodic_resample import periodic_resample
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -52,12 +53,11 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
def testPeriodicResampleBasic3D(self):
- input_tensor = numpy.arange(2*2*4).reshape((2, 2, 4))
+ input_tensor = numpy.arange(2 * 2 * 4).reshape((2, 2, 4))
desired_shape = numpy.array([4, 4, None])
- output_tensor = numpy.array([[[0], [2], [4], [6]],
- [[1], [3], [5], [7]],
- [[8], [10], [12], [14]],
- [[9], [11], [13], [15]]])
+ output_tensor = numpy.array([[[0], [2], [4], [6]], [[1], [3], [5], [7]],
+ [[8], [10], [12], [14]], [[9], [11], [13],
+ [15]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, -1))
with self.test_session():
@@ -71,24 +71,18 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
def testPeriodicResampleBasic4D(self):
- input_tensor = numpy.arange(2*2*2*8).reshape((2, 2, 2, 8))
+ input_tensor = numpy.arange(2 * 2 * 2 * 8).reshape((2, 2, 2, 8))
desired_shape = numpy.array([4, 4, 4, None])
- output_tensor = numpy.array([[[[0], [4], [8], [12]],
- [[2], [6], [10], [14]],
- [[16], [20], [24], [28]],
- [[18], [22], [26], [30]]],
- [[[1], [5], [9], [13]],
- [[3], [7], [11], [15]],
- [[17], [21], [25], [29]],
- [[19], [23], [27], [31]]],
- [[[32], [36], [40], [44]],
- [[34], [38], [42], [46]],
- [[48], [52], [56], [60]],
- [[50], [54], [58], [62]]],
- [[[33], [37], [41], [45]],
- [[35], [39], [43], [47]],
- [[49], [53], [57], [61]],
- [[51], [55], [59], [63]]]])
+ output_tensor = numpy.array(
+ [[[[0], [4], [8], [12]], [[2], [6], [10], [14]],
+ [[16], [20], [24], [28]], [[18], [22], [26], [30]]],
+ [[[1], [5], [9], [13]], [[3], [7], [11], [15]], [[17], [21], [25],
+ [29]],
+ [[19], [23], [27],
+ [31]]], [[[32], [36], [40], [44]], [[34], [38], [42], [46]],
+ [[48], [52], [56], [60]], [[50], [54], [58], [62]]],
+ [[[33], [37], [41], [45]], [[35], [39], [43], [47]],
+ [[49], [53], [57], [61]], [[51], [55], [59], [63]]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1))
with self.test_session():
@@ -96,6 +90,19 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
+ def testPeriodicResampleErrors(self):
+ input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ 'Dimension 3 input tensor has size 4, desired shape has size 1'):
+ periodic_resample(input_tensor, [None, 4, 4, 1]).eval()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ '4, to be the same as the length of the desired shape, 3'):
+ periodic_resample(input_tensor, [None, 4, 4]).eval()
+
-if __name__ == "__main__":
+if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py
index e8443e718d..578d9424b2 100644
--- a/tensorflow/contrib/predictor/predictor_factories_test.py
+++ b/tensorflow/contrib/predictor/predictor_factories_test.py
@@ -50,8 +50,8 @@ class PredictorFactoriesTest(test.TestCase):
def testFromContribEstimator(self):
estimator = testing_common.get_arithmetic_estimator(core=False)
input_fn = testing_common.get_arithmetic_input_fn(core=False)
- predictor_factories.from_contrib_estimator(estimator, input_fn,
- output_alternative_key='sum')
+ predictor_factories.from_contrib_estimator(
+ estimator, input_fn, output_alternative_key='sum')
def testFromContribEstimatorWithCoreEstimatorRaises(self):
estimator = testing_common.get_arithmetic_estimator(core=True)
diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/py2tf/BUILD
index 7358822ef5..d395de986d 100644
--- a/tensorflow/contrib/py2tf/BUILD
+++ b/tensorflow/contrib/py2tf/BUILD
@@ -26,7 +26,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/py2tf/convert",
+ "//tensorflow/contrib/py2tf/converters",
"//tensorflow/contrib/py2tf/pyct",
"//tensorflow/contrib/py2tf/pyct/static_analysis",
"@gast_archive//:gast",
@@ -46,7 +46,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/py2tf/convert",
+ "//tensorflow/contrib/py2tf/converters",
"//tensorflow/contrib/py2tf/pyct",
"//tensorflow/contrib/py2tf/pyct/static_analysis",
"@gast_archive//:gast",
diff --git a/tensorflow/contrib/py2tf/api.py b/tensorflow/contrib/py2tf/api.py
index 3a36720969..ca1f4e2645 100644
--- a/tensorflow/contrib/py2tf/api.py
+++ b/tensorflow/contrib/py2tf/api.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from functools import wraps
+
import gast
import six
@@ -32,7 +34,115 @@ from tensorflow.python.util import tf_inspect
# (currently we require (module + class name, type))
-def to_graph(o, arg_value_hints=None):
+def graph_ready(f):
+ """No-op decorator that explicitly marks a function as graph-ready.
+
+ Graph-ready functions are assumed to not need any conversion.
+
+ Args:
+ f: Any callable.
+ Returns:
+ f itself.
+ """
+ setattr(f, '__pyct_is_compile_decorator', True)
+ return f
+
+
+def convert_inline(f, *args, **kwargs):
+ """Shorthand to convert and call a function.
+
+ For example, the following two statements are equivalent:
+
+ @convert()
+ def foo():
+ ...
+ foo(bar)
+
+ def foo():
+ ...
+ convert_inline(foo, bar)
+
+ Args:
+ f: Function to convert. Only this call will be converted.
+ *args: Passed through to f.
+ **kwargs: Passed through to f, with the following exceptions:
+ * arg_value_hints: A dict mapping parameter names to objects that can
+ hint at the type of those parameters.
+
+ Returns:
+ The result of the converted f applied to args and kwargs.
+ """
+ if 'arg_value_hints' in kwargs:
+ arg_value_hints = kwargs['arg_value_hints']
+ del kwargs['arg_value_hints']
+ else:
+ arg_value_hints = None
+ if tf_inspect.ismethod(f):
+ # When converting methods, the result is still an unbound function.
+ args = (f.__self__,) + args
+ return convert(arg_value_hints)(f)(*args, **kwargs)
+
+
+def convert(recursive=False, arg_types=None):
+ """Decorator that compiles a function to graph mode.
+
+ The decorator is dynamic - invoking compilation whenever the decorated function
+ is called. This means the parameter values are known at compilation.
+
+ Args:
+ recursive: Whether to recusrively convert any functions that the decorator
+ function may call.
+ arg_types: See to_graph.
+
+ Returns:
+ A decorator that compiles the given function to graph mode.
+
+ Raises:
+ ValueError: If any of the arguments are illegal.
+ """
+ if arg_types is None:
+ arg_types = {}
+
+ def decorator(f):
+ """Decorator implementation."""
+
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ """Wrapper that calls the compiled version of the wrapped function."""
+ partial_types = ()
+ arg_values = {}
+ arg_names = tf_inspect.getargspec(f)[0]
+ for name, arg in zip(arg_names, args):
+ arg_values[name] = arg
+ arg_class = arg.__class__
+ # If arg_value_hints specifies any name, use that instead.
+ if name not in arg_types:
+ arg_types[name] = (arg_class.__name__, arg_class)
+ if name == 'self' and tf_inspect.isclass(arg_class):
+ # Annotated methods need to specify that their owner type is partial,
+ # otherwise other members they call will not be converted.
+ partial_types = (arg_class,)
+ wrapped = to_graph(
+ f,
+ recursive=recursive,
+ arg_values=arg_values,
+ arg_types=arg_types,
+ partial_types=partial_types)
+ return wrapped(*args, **kwargs)
+
+ # Sometimes the decorator is just desugared, making it impossible to detect.
+ # This attribute makes detection easier.
+ setattr(wrapper, '__pyct_is_compile_decorator', True)
+ return wrapper
+
+ return decorator
+
+
+def to_graph(e,
+ recursive=True,
+ arg_values=None,
+ arg_types=None,
+ partial_types=None):
"""Compile a Python entity into equivalent TensorFlow code.
Currently supported entities:
@@ -42,16 +152,26 @@ def to_graph(o, arg_value_hints=None):
Classes are handled by converting all their methods into a new class.
Args:
- o: A Python function or class.
- arg_value_hints: A dict mapping parameter names to objects that can hint
- at the type of those parameters.
+ e: A Python entity.
+ recursive: Whether to recusrively convert any functions that the decorator
+ function may call.
+ arg_values: A dict containing value hints for symbols like function
+ parameters.
+ arg_types: A dict containing type hints for symbols like function
+ parameters.
+ partial_types: A set of types (e.g. classes) that will not be converted
+ entirely. Calls to member functions for these types will be renamed
+ independently.
Returns:
A function with a signature identical to `o`, but which when executed it
creates TF a graph that has the same functionality as the original entity.
"""
- conversion_map = conversion.ConversionMap()
- _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints)
+ conversion_map = conversion.ConversionMap(
+ recursive=recursive,
+ nocompile_decorators=(convert, graph_ready, convert_inline),
+ partial_types=partial_types)
+ _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)
module = gast.Module([])
for import_line in config.COMPILED_IMPORT_STATEMENTS:
@@ -62,29 +182,39 @@ def to_graph(o, arg_value_hints=None):
# The compiled code should see everything the entry function saw.
# TODO(mdan): This might not work well if the call tree spans modules?
- if tf_inspect.isfunction(o):
- compiled_node.__dict__.update(six.get_function_globals(o))
+ if tf_inspect.isfunction(e):
+ compiled_node.__dict__.update(six.get_function_globals(e))
compiled_fn = getattr(compiled_node, name)
return compiled_fn
-def to_code(o, arg_value_hints=None, indentation=' '):
+def to_code(e,
+ recursive=True,
+ arg_values=None,
+ arg_types=None,
+ partial_types=None,
+ indentation=' '):
"""Return the equivalent of an entity in TensorFlow code.
See `to_graph` for more details.
Args:
- o: A Python function or class.
- arg_value_hints: A dict mapping parameter names to objects that can hint
- at the type of those parameters.
+ e: A Python entity.
+ recursive: See to_graph.
+ arg_values: See to_graph.
+ arg_types: See to_graph.
+ partial_types: See to_graph.
indentation: String, when to use for each level of indentation.
Returns:
String.
"""
- conversion_map = conversion.ConversionMap()
- conversion.object_to_graph(o, conversion_map, arg_value_hints)
+ conversion_map = conversion.ConversionMap(
+ recursive=recursive,
+ nocompile_decorators=(convert, graph_ready, convert_inline),
+ partial_types=partial_types)
+ conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)
imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)
code = '\n'.join(
diff --git a/tensorflow/contrib/py2tf/api_test.py b/tensorflow/contrib/py2tf/api_test.py
index 225b6d305f..2384447708 100644
--- a/tensorflow/contrib/py2tf/api_test.py
+++ b/tensorflow/contrib/py2tf/api_test.py
@@ -28,17 +28,146 @@ from tensorflow.python.platform import test
class ApiTest(test.TestCase):
+ def setUp(self):
+ config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,))
+ config.COMPILED_IMPORT_STATEMENTS = (
+ 'from tensorflow.python.ops '
+ 'import control_flow_ops as tf',)
+
+ def test_decorator_recurses(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_does_not_recurse(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ return math_ops.negative(a)
+
+ @api.convert(recursive=False)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_calls_converted(self):
+
+ class TestClass(object):
+
+ @api.graph_ready
+ def called_member(self, a):
+ return math_ops.negative(a)
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_calls_decorated(self):
+
+ class TestClass(object):
+
+ @api.convert()
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_convert_call_site_decorator(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= api.convert_inline(self.called_member, a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_graph_ready_call_site_decorator(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ return math_ops.negative(a)
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while math_ops.reduce_sum(x) > s:
+ x //= api.graph_ready(self.called_member(a))
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
def test_to_graph_basic(self):
def test_fn(x, s):
while math_ops.reduce_sum(x) > s:
x //= 2
return x
- config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,))
- config.COMPILED_IMPORT_STATEMENTS = (
- 'from tensorflow.python.ops '
- 'import control_flow_ops as tf',
- )
compiled_fn = api.to_graph(test_fn)
with self.test_session() as sess:
@@ -51,7 +180,6 @@ class ApiTest(test.TestCase):
x /= 2
return x
- config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,))
compiled_code = api.to_code(test_fn)
# Just check for some key words and that it is parseable Python code.
diff --git a/tensorflow/contrib/py2tf/config.py b/tensorflow/contrib/py2tf/config.py
index 0a9d52136e..8c502a7a9e 100644
--- a/tensorflow/contrib/py2tf/config.py
+++ b/tensorflow/contrib/py2tf/config.py
@@ -22,6 +22,7 @@ PYTHON_LITERALS = {
'None': None,
'False': False,
'True': True,
+ 'float': float,
}
DEFAULT_UNCOMPILED_MODULES = set((
diff --git a/tensorflow/contrib/py2tf/conversion.py b/tensorflow/contrib/py2tf/conversion.py
index 43bccae953..b484eebbd5 100644
--- a/tensorflow/contrib/py2tf/conversion.py
+++ b/tensorflow/contrib/py2tf/conversion.py
@@ -23,15 +23,17 @@ import six
from tensorflow.contrib.py2tf import config
from tensorflow.contrib.py2tf import naming
-from tensorflow.contrib.py2tf.convert import break_canonicalization
-from tensorflow.contrib.py2tf.convert import builtin_functions
-from tensorflow.contrib.py2tf.convert import call_trees
-from tensorflow.contrib.py2tf.convert import continue_canonicalization
-from tensorflow.contrib.py2tf.convert import control_flow
-from tensorflow.contrib.py2tf.convert import for_canonicalization
-from tensorflow.contrib.py2tf.convert import logical_expressions
-from tensorflow.contrib.py2tf.convert import print_functions
-from tensorflow.contrib.py2tf.convert import side_effect_guards
+from tensorflow.contrib.py2tf.converters import break_canonicalization
+from tensorflow.contrib.py2tf.converters import builtin_functions
+from tensorflow.contrib.py2tf.converters import call_trees
+from tensorflow.contrib.py2tf.converters import continue_canonicalization
+from tensorflow.contrib.py2tf.converters import control_flow
+from tensorflow.contrib.py2tf.converters import decorators
+from tensorflow.contrib.py2tf.converters import for_canonicalization
+from tensorflow.contrib.py2tf.converters import logical_expressions
+from tensorflow.contrib.py2tf.converters import print_functions
+from tensorflow.contrib.py2tf.converters import side_effect_guards
+from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
@@ -39,22 +41,35 @@ from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.util import tf_inspect
+# TODO(mdan): Might we not need any renaming at all?
+
+
class ConversionMap(object):
"""ConversionMaps keep track of converting function hierarchies.
Attributes:
- dependency_cache: dict[object]: ast; maps original objects to their
+ recursive: Whether to recusrively convert any functions that the decorator
+ function may call.
+ nocompile_decorators: tuple of decorator functions that toggle compilation
+ off.
+ dependency_cache: dict[object]: ast; maps original entities to their
converted AST
- name_map: dict[string]: string; maps original objects to the name of
+ name_map: dict[string]: string; maps original entities to the name of
their converted counterparts
"""
- def __init__(self):
+ # TODO(mdan): Rename to ConversionContext, and pull in additional flags.
+
+ def __init__(self, recursive, nocompile_decorators, partial_types):
+ self.recursive = recursive
+ self.nocompile_decorators = nocompile_decorators
+ self.partial_types = partial_types if partial_types else ()
self.dependency_cache = {}
self.name_map = {}
- def new_namer(self, global_symbols):
- return naming.Namer(global_symbols, self.name_map)
+ def new_namer(self, namespace):
+ return naming.Namer(namespace, self.recursive, self.name_map,
+ self.partial_types)
def update_name_map(self, namer):
for o, name in namer.renamed_calls.items():
@@ -62,77 +77,81 @@ class ConversionMap(object):
if self.name_map[o] != name:
raise ValueError(
'Calls to %s were converted using multiple names (%s). This is '
- 'possible when an object with one of these names already '
+ 'possible when an entity with one of these names already '
'existed. To fix, avoid using any of these names.')
else:
self.name_map[o] = name
- def add_to_cache(self, original_object, converted_ast):
- self.dependency_cache[original_object] = converted_ast
+ def add_to_cache(self, original_entity, converted_ast):
+ self.dependency_cache[original_entity] = converted_ast
-def object_to_graph(o, conversion_map, value_hints):
- """Compile a Python object into equivalent TensorFlow.
+def entity_to_graph(o, conversion_map, arg_values, arg_types):
+ """Compile a Python entity into equivalent TensorFlow.
- The function will also recursively compile all the objects that `o`
+ The function will also recursively compile all the entities that `o`
references, updating `dependency_cache`.
This function is reentrant, and relies on dependency_cache to avoid
generating duplicate code.
Args:
- o: A Python object.
+ o: A Python entity.
conversion_map: A ConversionMap object.
- value_hints: A dict containing value hints for symbols like function
+ arg_values: A dict containing value hints for symbols like function
+ parameters.
+ arg_types: A dict containing type hints for symbols like function
parameters.
Returns:
A tuple (ast, new_name):
- * ast: An AST representing an object with interface equivalent to `o`,
+ * ast: An AST representing an entity with interface equivalent to `o`,
but which when executed it creates TF a graph.
- * new_name: The symbol name under which the new object can be found.
+ * new_name: The symbol name under which the new entity can be found.
Raises:
- ValueError: if the object is not supported.
+ ValueError: if the entity type is not supported.
"""
- if value_hints is None:
- value_hints = {}
-
if tf_inspect.isclass(o):
- node, new_name = class_to_graph(o, conversion_map, value_hints)
+ node, new_name = class_to_graph(o, conversion_map)
elif tf_inspect.isfunction(o):
- node, new_name = function_to_graph(o, conversion_map, value_hints)
+ node, new_name = function_to_graph(o, conversion_map, arg_values, arg_types)
+ elif tf_inspect.ismethod(o):
+ node, new_name = function_to_graph(o, conversion_map, arg_values, arg_types)
else:
raise ValueError(
- 'Unsupported object type %s. Only functions and classes are supported'
- ' for now.')
+ 'Entity "%s" has unsupported type "%s". Only functions and classes are '
+ 'supported for now.' % (o, type(o)))
conversion_map.add_to_cache(o, node)
- # Recursively convert remaining dependencies.
- for obj in conversion_map.name_map.keys():
- if obj not in conversion_map.dependency_cache:
- if hasattr(obj, 'im_class'):
- # Class members are converted with their objects.
- continue
- object_to_graph(obj, conversion_map, None)
+ if conversion_map.recursive:
+ for obj in conversion_map.name_map.keys():
+ if obj not in conversion_map.dependency_cache:
+ if (hasattr(obj, 'im_class') and
+ getattr(obj, 'im_class') not in conversion_map.partial_types):
+ # Class members are converted with their objects, unless they're
+ # only converted partially.
+ continue
+ entity_to_graph(obj, conversion_map, {}, {})
return node, new_name
-def class_to_graph(c, conversion_map, param_value_hints):
- """Specialization of `object_to_graph` for classes."""
+def class_to_graph(c, conversion_map):
+ """Specialization of `entity_to_graph` for classes."""
converted_members = {}
members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
if not members:
raise ValueError('Cannot convert %s: it has no member methods.')
- if 'self' in param_value_hints:
- raise ValueError('Hints may not be provided for reserved name "self".')
- param_value_hints['self'] = (c.__name__, c)
-
class_globals = None
for _, m in members:
- node, _ = function_to_graph(m, conversion_map, param_value_hints, c)
+ node, _ = function_to_graph(
+ m,
+ conversion_map=conversion_map,
+ arg_values={},
+ arg_types={'self': (c.__name__, c)},
+ owner_type=c)
# TODO(mdan): Do not assume all members have the same view of globals.
if class_globals is None:
class_globals = six.get_function_globals(m)
@@ -149,10 +168,11 @@ def class_to_graph(c, conversion_map, param_value_hints):
return node, class_name
-def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
- """Specialization of `object_to_graph` for callable functions."""
+def function_to_graph(f, conversion_map, arg_values, arg_types,
+ owner_type=None):
+ """Specialization of `entity_to_graph` for callable functions."""
node = parser.parse_object(f).body[0]
- node_globals = six.get_function_globals(f)
+ namespace = six.get_function_globals(f)
# This is needed for non-global functions.
closure = six.get_function_closure(f)
@@ -160,10 +180,17 @@ def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
for e in closure:
if callable(e.cell_contents):
fn = e.cell_contents
- node_globals[fn.__name__] = fn
-
- namer = conversion_map.new_namer(node_globals)
- node = node_to_graph(node, namer, node_globals, param_value_hints)
+ namespace[fn.__name__] = fn
+
+ namer = conversion_map.new_namer(namespace)
+ ctx = context.EntityContext(
+ namer=namer,
+ source_code=tf_inspect.getsource(f),
+ source_file=tf_inspect.getfile(f),
+ namespace=namespace,
+ arg_values=arg_values,
+ arg_types=arg_types)
+ node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)
# Simulate a rename to ensure the top level is in the name map. This is needed
# for top level functions, and it also helps the consistency verification made
@@ -177,29 +204,30 @@ def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
return node, conversion_map.name_map[f]
-def _static_analysis_pass(node, namespace, value_hints):
+def _static_analysis_pass(node, ctx):
node = access.resolve(node)
- node = live_values.resolve(node, namespace, config.PYTHON_LITERALS)
- node = type_info.resolve(node, value_hints)
+ node = live_values.resolve(node, ctx.namespace, config.PYTHON_LITERALS)
+ node = type_info.resolve(node, ctx)
return node
-def node_to_graph(node, namer, namespace, value_hints):
+def node_to_graph(node, ctx, nocompile_decorators):
"""Convert Python code to equivalent TF graph mode code.
Args:
node: A Python AST node representing the code to convert.
- namer: A naming.Namer object.
- namespace: Dict mapping symbol names to their corresponding live objects.
- value_hints: A dict containing value hints for symbols like function
- parameters.
+ ctx: An EntityContext object.
+ nocompile_decorators: A tuple containing decorators to be stripped from
+ functions during conversion.
Returns:
A tuple (node, deps):
* node: A Python ast node, representing the converted code.
- * deps: A set of strings, the fully qualified names of object
+ * deps: A set of strings, the fully qualified names of entity
dependencies that this node has.
"""
+ # TODO(mdan): Verify arguments for correctness.
+
# TODO(mdan): Factor out common elements.
# These include:
# * keeping track of symbols that have been created
@@ -212,27 +240,30 @@ def node_to_graph(node, namer, namespace, value_hints):
# tree, which must be accounted. Although less efficient, it is most robust
# to re-run the analysis.
- node = _static_analysis_pass(node, namespace, value_hints)
- node = break_canonicalization.transform(node, namer)
+ node = _static_analysis_pass(node, ctx)
+ node = decorators.transform(node, nocompile_decorators)
+ node = break_canonicalization.transform(node, ctx.namer)
# Note: sequencing continue canonicalization before for loop one avoids
# dealing with the extra loop increment operation that the for
# canonicalization creates.
- node = continue_canonicalization.transform(node, namer)
- namespace['len'] = len
+ node = continue_canonicalization.transform(node, ctx.namer)
+ ctx.namespace['len'] = len
- node = _static_analysis_pass(node, namespace, value_hints)
- node = for_canonicalization.transform(node, namer)
+ node = _static_analysis_pass(node, ctx)
+ node = for_canonicalization.transform(node, ctx.namer)
# for_canonicalization may insert new global references.
node = builtin_functions.transform(node)
# builtin_functions may insert new global references.
- namespace['print'] = print
+ ctx.namespace['print'] = print
- node = _static_analysis_pass(node, namespace, value_hints)
+ node = _static_analysis_pass(node, ctx)
node = print_functions.transform(node)
- node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES)
- node = control_flow.transform(node, namer)
+ node = call_trees.transform(node, ctx.namer, ctx.namespace,
+ config.DEFAULT_UNCOMPILED_MODULES,
+ nocompile_decorators)
+ node = control_flow.transform(node, ctx.namer)
node = logical_expressions.transform(node)
- node = side_effect_guards.transform(node, namer)
+ node = side_effect_guards.transform(node, ctx.namer)
return node
diff --git a/tensorflow/contrib/py2tf/conversion_test.py b/tensorflow/contrib/py2tf/conversion_test.py
index d76f141809..26f915f4f4 100644
--- a/tensorflow/contrib/py2tf/conversion_test.py
+++ b/tensorflow/contrib/py2tf/conversion_test.py
@@ -26,28 +26,31 @@ from tensorflow.python.platform import test
class ConversionTest(test.TestCase):
- def test_object_to_graph_unsupported_types(self):
+ def test_entity_to_graph_unsupported_types(self):
with self.assertRaises(ValueError):
- conversion.object_to_graph('dummy', {}, {})
+ conversion_map = conversion.ConversionMap(True, (), ())
+ conversion.entity_to_graph('dummy', conversion_map, None, None)
+
+ def test_entity_to_graph_callable(self):
- def test_object_to_graph_callable(self):
def f(a):
return a
- conversion_map = conversion.ConversionMap()
- ast, new_name = conversion.object_to_graph(f, conversion_map, {})
+ conversion_map = conversion.ConversionMap(True, (), ())
+ ast, new_name = conversion.entity_to_graph(f, conversion_map, None, None)
self.assertTrue(isinstance(ast, gast.FunctionDef), ast)
self.assertEqual('tf__f', new_name)
- def test_object_to_graph_call_tree(self):
+ def test_entity_to_graph_call_tree(self):
+
def g(a):
return a
def f(a):
return g(a)
- conversion_map = conversion.ConversionMap()
- conversion.object_to_graph(f, conversion_map, {})
+ conversion_map = conversion.ConversionMap(True, (), ())
+ conversion.entity_to_graph(f, conversion_map, None, None)
self.assertTrue(f in conversion_map.dependency_cache)
self.assertTrue(g in conversion_map.dependency_cache)
diff --git a/tensorflow/contrib/py2tf/convert/BUILD b/tensorflow/contrib/py2tf/converters/BUILD
index 0eb7998dc4..2b0a1234e6 100644
--- a/tensorflow/contrib/py2tf/convert/BUILD
+++ b/tensorflow/contrib/py2tf/converters/BUILD
@@ -15,13 +15,14 @@ filegroup(
)
py_library(
- name = "convert",
+ name = "converters",
srcs = [
"break_canonicalization.py",
"builtin_functions.py",
"call_trees.py",
"continue_canonicalization.py",
"control_flow.py",
+ "decorators.py",
"for_canonicalization.py",
"logical_expressions.py",
"print_functions.py",
@@ -34,13 +35,26 @@ py_library(
],
)
+py_library(
+ name = "test_lib",
+ srcs = [
+ "converter_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":converters",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "@gast_archive//:gast",
+ ],
+)
+
py_test(
name = "break_canonicalization_test",
srcs = ["break_canonicalization_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
],
)
@@ -49,9 +63,8 @@ py_test(
name = "call_trees_test",
srcs = ["call_trees_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
],
)
@@ -60,9 +73,8 @@ py_test(
name = "continue_canonicalization_test",
srcs = ["continue_canonicalization_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
],
)
@@ -71,9 +83,8 @@ py_test(
name = "control_flow_test",
srcs = ["control_flow_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
],
)
@@ -82,9 +93,8 @@ py_test(
name = "builtin_functions_test",
srcs = ["builtin_functions_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
],
)
@@ -93,9 +103,8 @@ py_test(
name = "for_canonicalization_test",
srcs = ["for_canonicalization_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
],
)
@@ -104,9 +113,8 @@ py_test(
name = "logical_expressions_test",
srcs = ["logical_expressions_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
],
)
@@ -115,9 +123,8 @@ py_test(
name = "print_functions_test",
srcs = ["print_functions_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
"@gast_archive//:gast",
],
@@ -127,9 +134,8 @@ py_test(
name = "side_effect_guards_test",
srcs = ["side_effect_guards_test.py"],
deps = [
- ":convert",
+ ":test_lib",
"//tensorflow/contrib/py2tf/pyct",
- "//tensorflow/contrib/py2tf/pyct/static_analysis",
"//tensorflow/python:client_testlib",
],
)
diff --git a/tensorflow/contrib/py2tf/convert/__init__.py b/tensorflow/contrib/py2tf/converters/__init__.py
index ca10896ee5..ca10896ee5 100644
--- a/tensorflow/contrib/py2tf/convert/__init__.py
+++ b/tensorflow/contrib/py2tf/converters/__init__.py
diff --git a/tensorflow/contrib/py2tf/convert/break_canonicalization.py b/tensorflow/contrib/py2tf/converters/break_canonicalization.py
index ef58573445..2ae65e3007 100644
--- a/tensorflow/contrib/py2tf/convert/break_canonicalization.py
+++ b/tensorflow/contrib/py2tf/converters/break_canonicalization.py
@@ -33,31 +33,25 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer):
self.break_uses = []
def _create_break_check(self):
-
- def template(var_name):
- (not var_name) # pylint:disable=pointless-statement
-
- expr, = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ template = """
+ (not var_name)
+ """
+ expr, = templates.replace(template, var_name=self.break_uses[-1][1])
return expr.value
def _create_break_trigger(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = True
-
- block = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ """
+ block = templates.replace(template, var_name=self.break_uses[-1][1])
block.append(gast.Continue())
return block
def _create_break_init(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = False
-
- assign, = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ """
+ assign, = templates.replace(template, var_name=self.break_uses[-1][1])
return assign
# TODO(mdan): Surely the transformer supports this better?
diff --git a/tensorflow/contrib/py2tf/convert/break_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py
index 23c4c4d3e2..b5ba2ad923 100644
--- a/tensorflow/contrib/py2tf/convert/break_canonicalization_test.py
+++ b/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.convert import break_canonicalization
-from tensorflow.contrib.py2tf.convert import control_flow
+from tensorflow.contrib.py2tf.converters import break_canonicalization
+from tensorflow.contrib.py2tf.converters import control_flow
+from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.python.platform import test
@@ -32,12 +31,7 @@ class TestNamer(control_flow.SymbolNamer):
return name_root
-class BreakCanonicalizationTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn, namespace):
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- return node
+class BreakCanonicalizationTest(converter_test_base.TestCase):
def test_basic_break(self):
@@ -50,7 +44,7 @@ class BreakCanonicalizationTest(test.TestCase):
v.append(x)
return v
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
node = break_canonicalization.transform(node, TestNamer())
result = compiler.ast_to_object(node)
@@ -82,7 +76,7 @@ class BreakCanonicalizationTest(test.TestCase):
v.append(x)
return v
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
node = break_canonicalization.transform(node, TestNamer())
result = compiler.ast_to_object(node)
@@ -110,7 +104,7 @@ class BreakCanonicalizationTest(test.TestCase):
v.append(x)
return v, u, w
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
node = break_canonicalization.transform(node, TestNamer())
result = compiler.ast_to_object(node)
diff --git a/tensorflow/contrib/py2tf/convert/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py
index b80c96c97a..7f6b64a34c 100644
--- a/tensorflow/contrib/py2tf/convert/builtin_functions.py
+++ b/tensorflow/contrib/py2tf/converters/builtin_functions.py
@@ -29,10 +29,9 @@ class BuiltinFunctionTransformer(gast.NodeTransformer):
# TODO(mdan): Bring print_functions in here.
def _convert_len(self, node):
-
- def template(args):
- tf.shape(args)[0] # pylint:disable=undefined-variable,expression-not-assigned
-
+ template = """
+ tf.shape(args)[0]
+ """
new_call = templates.replace(template, args=node.args)[0].value
return new_call
diff --git a/tensorflow/contrib/py2tf/convert/builtin_functions_test.py b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py
index 633602f4d4..b5358da6bc 100644
--- a/tensorflow/contrib/py2tf/convert/builtin_functions_test.py
+++ b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py
@@ -18,32 +18,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.convert import builtin_functions
+from tensorflow.contrib.py2tf.converters import builtin_functions
+from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
-from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
-from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class BuiltinFunctionsTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn, namespace):
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, {})
- return node
+class BuiltinFunctionsTest(converter_test_base.TestCase):
def test_len(self):
def test_fn(a):
return len(a)
- node = self._parse_and_analyze(test_fn, {'len': len})
+ node = self.parse_and_analyze(test_fn, {'len': len})
node = builtin_functions.transform(node)
result = compiler.ast_to_object(node)
setattr(result, 'tf', array_ops)
diff --git a/tensorflow/contrib/py2tf/convert/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py
index 92c3439101..0aae030450 100644
--- a/tensorflow/contrib/py2tf/convert/call_trees.py
+++ b/tensorflow/contrib/py2tf/converters/call_trees.py
@@ -27,6 +27,7 @@ import types
import gast
from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import templates
@@ -64,16 +65,75 @@ class FunctionNamer(object):
class CallTreeTransformer(gast.NodeTransformer):
"""Transforms the call tree by renaming transformed symbols."""
- def __init__(self, namer, uncompiled_modules):
+ def __init__(self, namer, namespace, uncompiled_modules,
+ nocompile_decorators):
self.namer = namer
+ self.namespace = namespace
self.uncompiled_modules = uncompiled_modules
+ self.nocompile_decorators = nocompile_decorators
# pylint:disable=invalid-name
- def _should_compile(self, fqn):
+ def _resolve_name(self, node):
+ if isinstance(node, gast.Call):
+ return self._resolve_name(node.func)
+ if isinstance(node, gast.Name):
+ return self.namespace.get(node.id)
+ if isinstance(node, gast.Attribute):
+ parent = self._resolve_name(node.value)
+ if parent is not None:
+ return getattr(parent, node.attr)
+ return None
+ raise ValueError(node)
+
+ def _try_resolve_target(self, node):
+ """Works for methods of objects of known type."""
+ if anno.hasanno(node, 'live_val'):
+ return anno.getanno(node, 'live_val')
+ if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
+ member = getattr(anno.getanno(node, 'type'), node.attr)
+ return member
+ return None
+
+ def _should_compile(self, node, fqn):
for i in range(1, len(fqn)):
if fqn[:i] in self.uncompiled_modules:
return False
+
+ # Check for local decorations
+ if anno.hasanno(node, 'graph_ready'):
+ return False
+
+ # The decorators themselves are not to be converted.
+ # If present, the decorators should appear as static functions.
+ target_obj = self._try_resolve_target(node.func)
+ if target_obj is not None:
+ # This attribute is set by the decorator itself.
+ # TODO(mdan): This may not play nicely with other wrapping decorators.
+ if hasattr(target_obj, '__pyct_is_compile_decorator'):
+ return False
+
+ if target_obj in self.nocompile_decorators:
+ return False
+
+ # Inspect the target function decorators. If any include a @convert
+ # or @graph_ready annotation, then they must be called as they are.
+ # TODO(mdan): This may be quite heavy.
+ # To parse and re-analize each function for every call site could be quite
+ # wasteful. Maybe we could cache the parsed AST?
+ try:
+ target_node = parser.parse_object(target_obj).body[0]
+ except TypeError:
+ # Functions whose source we cannot access are compilable (e.g. wrapped
+ # to py_func).
+ return True
+
+ for dec in target_node.decorator_list:
+ decorator_fn = self._resolve_name(dec)
+ if (decorator_fn is not None and
+ decorator_fn in self.nocompile_decorators):
+ return False
+
return True
def _rename_compilable_function(self, node):
@@ -82,16 +142,16 @@ class CallTreeTransformer(gast.NodeTransformer):
target_obj = anno.getanno(node.func, 'live_val')
target_fqn = anno.getanno(node.func, 'fqn')
- if not self._should_compile(target_fqn):
+ if not self._should_compile(node, target_fqn):
return node
if anno.hasanno(node, 'is_constructor'):
new_name = self.namer.compiled_class_name(
- '.'.join(target_fqn), live_object=target_obj)
+ '__'.join(target_fqn), live_object=target_obj)
else:
new_name = self.namer.compiled_function_name(
- '.'.join(target_fqn), live_object=target_obj)
- node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
+ '__'.join(target_fqn), live_object=target_obj)
+ node.func = gast.Name(new_name, gast.Load(), None)
return node
def _rename_member_function_of_known_type(self, node):
@@ -101,41 +161,42 @@ class CallTreeTransformer(gast.NodeTransformer):
assert anno.hasanno(node.func, 'type')
target_type = anno.getanno(node.func, 'type')
- if not self._should_compile(type_fqn):
+ if not self._should_compile(node, type_fqn):
return node
# TODO(mdan): We should not assume that the namer only needs the
# member function name.
+ method_name = node.func.attr
+ method_object = getattr(target_type, method_name)
new_name = self.namer.compiled_function_name(
- node.func.attr, live_object=None, owner_type=target_type)
- node.func.attr = new_name
-
+ method_name, live_object=method_object, owner_type=target_type)
+ if new_name != node.func.attr:
+ # If a member function call is renamed, then the new function is no
+ # longer bound to the target object. We then refactor the call from:
+ # foo.bar(...)
+ # to:
+ # renamed_foo(bar, ...)
+ # TODO(mdan): This risks causing duplication, if target_type is renamed.
+ node.args = [node.func.value] + node.args
+ node.func = gast.Name(new_name, gast.Load(), None)
return node
def _wrap_to_py_func_no_return(self, node):
args_scope = anno.getanno(node, 'args_scope')
# TODO(mdan): Properly handle varargs, kwargs, etc.
- args = tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used)
-
- # pylint:disable=undefined-variable,unused-argument,function-redefined
-
- def template(call, wrapper, args):
-
+ template = """
def wrapper(args):
call(args)
return 1
-
tf.py_func(wrapper, [args], [tf.int64])
-
- # pylint:enable=undefined-variable,unused-argument,function-redefined
-
- wrapper_name = self.namer.compiled_function_name(node.func.id)
+ """
wrapper_def, call_expr = templates.replace(
template,
call=node.func,
- wrapper=gast.Name(wrapper_name, gast.Load(), None),
- args=args)
+ wrapper=self.namer.compiled_function_name(node.func.id),
+ args=tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used))
anno.setanno(call_expr.value, 'args_scope', args_scope)
+ # TODO(mdan): Rename this annotation to 'graph_ready'
anno.setanno(wrapper_def, 'skip_processing', True)
return (wrapper_def, call_expr)
@@ -151,7 +212,7 @@ class CallTreeTransformer(gast.NodeTransformer):
if not self._function_is_compilable(target_obj):
if anno.hasanno(node.value.func, 'fqn'):
target_fqn = anno.getanno(node.value.func, 'fqn')
- if not self._should_compile(target_fqn):
+ if not self._should_compile(node.value, target_fqn):
return node
node = self._wrap_to_py_func_no_return(node.value)
return node
@@ -163,6 +224,17 @@ class CallTreeTransformer(gast.NodeTransformer):
return node
def visit_Call(self, node):
+ # If the function is wrapped by one of the marker decorators,
+ # consider it graph ready.
+ if anno.hasanno(node.func, 'live_val'):
+ target_obj = anno.getanno(node.func, 'live_val')
+ if target_obj in self.nocompile_decorators:
+ if len(node.args) < 1:
+ raise ValueError(
+ 'Found call to decorator function "%s", but it had no arguments. '
+ 'A decorator needs at least an argument.')
+ anno.setanno(node.args[0], 'graph_ready', True)
+
self.generic_visit(node)
if anno.hasanno(node.func, 'live_val'):
target_obj = anno.getanno(node.func, 'live_val')
@@ -180,20 +252,24 @@ class CallTreeTransformer(gast.NodeTransformer):
# pylint:enable=invalid-name
-def transform(node, namer, uncompiled_modules):
+def transform(node, namer, namespace, uncompiled_modules, nocompile_decorators):
"""Transform function call to the compiled counterparts.
Args:
node: AST to transform.
namer: FunctionNamer-like.
+ namespace: Dict mapping symbol names to their corresponding live objects.
uncompiled_modules: set of string tuples, each tuple represents the fully
qualified name of a package containing functions that will not be
compiled.
+ nocompile_decorators: A tuple containing decorators to be stripped from
+ functions during conversion.
Returns:
A tuple (node, new_names):
node: The transformed AST
new_names: set(string), containing any newly-generated names
"""
- transformer = CallTreeTransformer(namer, uncompiled_modules)
+ transformer = CallTreeTransformer(namer, namespace, uncompiled_modules,
+ nocompile_decorators)
node = transformer.visit(node)
return node
diff --git a/tensorflow/contrib/py2tf/convert/call_trees_test.py b/tensorflow/contrib/py2tf/converters/call_trees_test.py
index 38c701eaad..8cb8d7be0f 100644
--- a/tensorflow/contrib/py2tf/convert/call_trees_test.py
+++ b/tensorflow/contrib/py2tf/converters/call_trees_test.py
@@ -18,12 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.convert import call_trees
+from tensorflow.contrib.py2tf.converters import call_trees
+from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
-from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
-from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -35,14 +32,7 @@ class TestNamer(call_trees.FunctionNamer):
return 'renamed_%s' % original_name
-class CallTreesTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn, namespace):
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, {})
- return node
+class CallTreesTest(converter_test_base.TestCase):
def test_basic(self):
@@ -55,8 +45,8 @@ class CallTreesTest(test.TestCase):
def test_fn_2(a):
return test_fn_1(a) + 1
- node = self._parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
- node = call_trees.transform(node, TestNamer(), set())
+ node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
+ node = call_trees.transform(node, TestNamer(), {}, (), ())
result = compiler.ast_to_object(node)
# Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
@@ -70,13 +60,13 @@ class CallTreesTest(test.TestCase):
a = math_ops.add(a, constant_op.constant(1))
return a
- node = self._parse_and_analyze(test_fn, {
+ node = self.parse_and_analyze(test_fn, {
'math_ops': math_ops,
'constant_op': constant_op
})
- node = call_trees.transform(node, TestNamer(),
+ node = call_trees.transform(node, TestNamer(), {},
set(((math_ops.__name__,),
- (constant_op.__name__,))))
+ (constant_op.__name__,))), ())
result = compiler.ast_to_object(node)
setattr(result, 'math_ops', math_ops)
setattr(result, 'constant_op', constant_op)
diff --git a/tensorflow/contrib/py2tf/convert/continue_canonicalization.py b/tensorflow/contrib/py2tf/converters/continue_canonicalization.py
index 7f8ace77a8..486f0f6509 100644
--- a/tensorflow/contrib/py2tf/convert/continue_canonicalization.py
+++ b/tensorflow/contrib/py2tf/converters/continue_canonicalization.py
@@ -33,32 +33,28 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer):
self.continuation_uses = []
def _create_continuation_check(self):
-
- def template(var_name):
+ template = """
if not var_name:
pass
-
- cond, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ """
+ cond, = templates.replace(template, var_name=self.continuation_uses[-1][1])
cond.body = []
return cond
def _create_continuation_trigger(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = True
-
+ """
assign, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ template, var_name=self.continuation_uses[-1][1])
return assign
def _create_continuation_init(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = False
-
+ """
assign, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ template, var_name=self.continuation_uses[-1][1])
return assign
def _visit_and_reindent_if_necessary(self, nodes):
diff --git a/tensorflow/contrib/py2tf/convert/continue_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py
index a041ff4641..c1fe903a2d 100644
--- a/tensorflow/contrib/py2tf/convert/continue_canonicalization_test.py
+++ b/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.convert import continue_canonicalization
-from tensorflow.contrib.py2tf.convert import control_flow
+from tensorflow.contrib.py2tf.converters import continue_canonicalization
+from tensorflow.contrib.py2tf.converters import control_flow
+from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.python.platform import test
@@ -32,12 +31,7 @@ class TestNamer(control_flow.SymbolNamer):
return name_root
-class ContinueCanonicalizationTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn, namespace):
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- return node
+class ContinueCanonicalizationTest(converter_test_base.TestCase):
def test_basic_continue(self):
@@ -50,7 +44,7 @@ class ContinueCanonicalizationTest(test.TestCase):
v.append(x)
return v
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
node = continue_canonicalization.transform(node, TestNamer())
result = compiler.ast_to_object(node)
@@ -71,7 +65,7 @@ class ContinueCanonicalizationTest(test.TestCase):
v.append(x)
return v
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
node = continue_canonicalization.transform(node, TestNamer())
result = compiler.ast_to_object(node)
@@ -97,7 +91,7 @@ class ContinueCanonicalizationTest(test.TestCase):
v.append(x)
return v, u, w
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {}, include_type_analysis=False)
node = continue_canonicalization.transform(node, TestNamer())
result = compiler.ast_to_object(node)
diff --git a/tensorflow/contrib/py2tf/convert/control_flow.py b/tensorflow/contrib/py2tf/converters/control_flow.py
index 8ebd9ad93d..a40c7b28f7 100644
--- a/tensorflow/contrib/py2tf/convert/control_flow.py
+++ b/tensorflow/contrib/py2tf/converters/control_flow.py
@@ -75,29 +75,6 @@ class ControlFlowTransformer(gast.NodeTransformer):
raise ValueError(
'The else branch creates new symbols that the if branch does not.')
- def template( # pylint:disable=missing-docstring
- test,
- body_name,
- body,
- orelse_name,
- orelse,
- aliased,
- aliases, # pylint:disable=unused-argument
- aliased_results,
- results): # pylint:disable=unused-argument
-
- def body_name(): # pylint:disable=function-redefined
- aliases, = aliased, # pylint:disable=unused-variable
- body # pylint:disable=pointless-statement
- return (aliased_results,)
-
- def orelse_name(): # pylint:disable=function-redefined
- aliases, = aliased, # pylint:disable=unused-variable
- orelse # pylint:disable=pointless-statement
- return (aliased_results,)
-
- results = tf.cond(test, body_name, orelse_name) # pylint:disable=undefined-variable
-
all_modified = tuple(body_scope.modified | orelse_scope.modified)
all_referenced = body_scope.referenced | orelse_scope.referenced
@@ -107,10 +84,10 @@ class ControlFlowTransformer(gast.NodeTransformer):
need_alias = (
(body_scope.modified | orelse_scope.modified) -
(body_scope.created | orelse_scope.created))
- aliased = tuple(need_alias)
- aliases = tuple(
- self.namer.new_symbol(s, all_referenced) for s in aliased)
- alias_map = dict(zip(aliased, aliases))
+ aliased_orig_names = tuple(need_alias)
+ aliased_new_names = tuple(
+ self.namer.new_symbol(s, all_referenced) for s in aliased_orig_names)
+ alias_map = dict(zip(aliased_orig_names, aliased_new_names))
node_body = node.body
node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body]
node_orelse = node.orelse
@@ -122,20 +99,29 @@ class ControlFlowTransformer(gast.NodeTransformer):
results = gast.Tuple(
tuple(gast.Name(s, None, None) for s in all_modified), None)
+ template = """
+ def body_name():
+ aliased_new_names, = aliased_orig_names,
+ body
+ return (all_results,)
+ def orelse_name():
+ aliased_new_names, = aliased_orig_names,
+ orelse
+ return (all_results,)
+ results = tf.cond(test, body_name, orelse_name)
+ """
+ body_name = self.namer.new_symbol('if_true', all_referenced)
return templates.replace(
template,
test=node.test,
- body_name=gast.Name(
- self.namer.new_symbol('if_true', all_referenced), None, None),
+ body_name=body_name,
body=node_body,
- orelse_name=gast.Name(
- self.namer.new_symbol('if_false', all_referenced), None, None),
+ orelse_name=self.namer.new_symbol('if_false', all_referenced),
orelse=node_orelse,
- aliased=tuple(gast.Name(s, None, None) for s in aliased),
- aliases=tuple(gast.Name(s, None, None) for s in aliases),
- aliased_results=tuple(
- gast.Name(alias_map[s] if s in aliased else s, None, None)
- for s in all_modified),
+ aliased_orig_names=tuple(aliased_orig_names),
+ aliased_new_names=tuple(aliased_new_names),
+ all_results=tuple(alias_map[s] if s in aliased_orig_names else s
+ for s in all_modified),
results=results)
def visit_While(self, node):
@@ -144,38 +130,28 @@ class ControlFlowTransformer(gast.NodeTransformer):
body_scope = anno.getanno(node, 'body_scope')
body_closure = tuple(body_scope.modified - body_scope.created)
- def template(
- state, # pylint:disable=unused-argument
- state_ast_tuple, # pylint:disable=unused-argument
- test_name,
- test, # pylint:disable=unused-argument
- body_name,
- body):
-
- def test_name(state): # pylint:disable=function-redefined,unused-argument
- return test
-
- def body_name(state): # pylint:disable=function-redefined,unused-argument
- body # pylint:disable=pointless-statement
- return state,
-
- state_ast_tuple = tf.while_loop(test_name, body_name, [state]) # pylint:disable=undefined-variable
-
- test_name = self.namer.new_symbol('loop_test', body_scope.referenced)
- body_name = self.namer.new_symbol('loop_body', body_scope.referenced)
if len(body_closure) == 1:
- state = gast.Name(body_closure[0], None, None)
+ state = body_closure[0]
state_ast_tuple = state
else:
- state = tuple(gast.Name(n, None, None) for n in body_closure)
- state_ast_tuple = gast.Tuple(state, None)
+ state = tuple(body_closure)
+ state_ast_tuple = gast.Tuple(
+ tuple(gast.Name(n, None, None) for n in state), None)
+ template = """
+ def test_name(state):
+ return test
+ def body_name(state):
+ body
+ return state,
+ state_ast_tuple = tf.while_loop(test_name, body_name, [state])
+ """
node = templates.replace(
template,
state=state,
state_ast_tuple=state_ast_tuple,
- test_name=gast.Name(test_name, gast.Load(), None),
+ test_name=self.namer.new_symbol('loop_test', body_scope.referenced),
test=node.test,
- body_name=gast.Name(body_name, gast.Load(), None),
+ body_name=self.namer.new_symbol('loop_body', body_scope.referenced),
body=node.body)
return node
diff --git a/tensorflow/contrib/py2tf/convert/control_flow_test.py b/tensorflow/contrib/py2tf/converters/control_flow_test.py
index 121af4ee94..054e33750d 100644
--- a/tensorflow/contrib/py2tf/convert/control_flow_test.py
+++ b/tensorflow/contrib/py2tf/converters/control_flow_test.py
@@ -18,12 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.convert import control_flow
+from tensorflow.contrib.py2tf.converters import control_flow
+from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
-from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
-from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
@@ -40,14 +37,7 @@ class TestNamer(control_flow.SymbolNamer):
i += 1
-class ControlFlowTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn, namespace):
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, {})
- return node
+class ControlFlowTest(converter_test_base.TestCase):
def test_simple_while(self):
@@ -59,7 +49,7 @@ class ControlFlowTest(test.TestCase):
i += 1
return s, i, n
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, TestNamer())
result = compiler.ast_to_object(node)
setattr(result, 'tf', control_flow_ops)
@@ -75,7 +65,7 @@ class ControlFlowTest(test.TestCase):
n -= 1
return n
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, TestNamer())
result = compiler.ast_to_object(node)
setattr(result, 'tf', control_flow_ops)
@@ -94,7 +84,7 @@ class ControlFlowTest(test.TestCase):
b = 2 * n
return a, b
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, TestNamer())
result = compiler.ast_to_object(node)
setattr(result, 'tf', control_flow_ops)
@@ -112,7 +102,7 @@ class ControlFlowTest(test.TestCase):
n = -n
return n
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, TestNamer())
result = compiler.ast_to_object(node)
setattr(result, 'tf', control_flow_ops)
diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/py2tf/converters/converter_test_base.py
new file mode 100644
index 0000000000..ed006bad6d
--- /dev/null
+++ b/tensorflow/contrib/py2tf/converters/converter_test_base.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for tests in this module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.pyct import context
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.platform import test
+
+
+class TestCase(test.TestCase):
+
+ def parse_and_analyze(self,
+ test_fn,
+ namespace,
+ arg_types=None,
+ include_type_analysis=True):
+ ctx = context.EntityContext(
+ namer=None,
+ source_code=None,
+ source_file=None,
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types)
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ node = live_values.resolve(node, namespace, {})
+ if include_type_analysis:
+ node = type_info.resolve(node, ctx)
+ return node
diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/py2tf/converters/decorators.py
new file mode 100644
index 0000000000..a4313bfa51
--- /dev/null
+++ b/tensorflow/contrib/py2tf/converters/decorators.py
@@ -0,0 +1,56 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles decorators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import pretty_printer
+
+
+class DecoratorsTransformer(gast.NodeTransformer):
+ """Converts or removes decorators."""
+
+ def __init__(self, remove_decorators):
+ self.remove_decorators = remove_decorators
+
+ # pylint:disable=invalid-name
+
+ def visit_FunctionDef(self, node):
+ self.generic_visit(node)
+ for dec in node.decorator_list:
+ if isinstance(dec, gast.Call):
+ dec = dec.func
+ if not anno.hasanno(dec, 'live_val'):
+ raise ValueError(
+ 'Could not resolve decorator: %s' % pretty_printer.fmt(dec))
+ dec_value = anno.getanno(dec, 'live_val')
+ if dec_value in self.remove_decorators:
+ continue
+ raise ValueError('Dont know how to convert decorators for now.')
+ node.decorator_list = []
+ return node
+
+ # pylint:enable=invalid-name
+
+
+def transform(node, remove_decorators):
+ transformer = DecoratorsTransformer(remove_decorators)
+ node = transformer.visit(node)
+ return node
diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization.py b/tensorflow/contrib/py2tf/converters/for_canonicalization.py
index 52360789cd..c284689b90 100644
--- a/tensorflow/contrib/py2tf/convert/for_canonicalization.py
+++ b/tensorflow/contrib/py2tf/converters/for_canonicalization.py
@@ -42,46 +42,40 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
# Or maybe we should replace range with tf.range?
if anno.hasanno(node, 'extra_cond'):
-
- def template(loop_iter, target, body, i, n, extra_cond): # pylint:disable=unused-argument
+ template = """
i = 0
- n = len(loop_iter) # pylint:disable=undefined-variable
+ n = len(loop_iter)
while i < n and extra_cond:
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
target = loop_iter[i]
- body # pylint:disable=pointless-statement
+ body
i += 1
-
+ """
return templates.replace(
template,
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=gast.Name(
- self.namer.new_symbol('i', body_scope.referenced), None, None),
- n=gast.Name(
- self.namer.new_symbol('n', body_scope.referenced), None, None),
+ i=self.namer.new_symbol('i', body_scope.referenced),
+ n=self.namer.new_symbol('n', body_scope.referenced),
extra_cond=anno.getanno(node, 'extra_cond'))
else:
-
- def template(loop_iter, target, body, i, n): # pylint:disable=unused-argument
+ template = """
i = 0
- n = len(loop_iter) # pylint:disable=undefined-variable
+ n = len(loop_iter)
while i < n:
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
target = loop_iter[i]
body # pylint:disable=pointless-statement
i += 1
-
+ """
return templates.replace(
template,
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=gast.Name(
- self.namer.new_symbol('i', body_scope.referenced), None, None),
- n=gast.Name(
- self.namer.new_symbol('n', body_scope.referenced), None, None))
+ i=self.namer.new_symbol('i', body_scope.referenced),
+ n=self.namer.new_symbol('n', body_scope.referenced))
def visit_Continue(self, node):
assert False, 'continue statement should be desugared at this point'
diff --git a/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py
index 8de2d1a0f8..a6e6350fd4 100644
--- a/tensorflow/contrib/py2tf/convert/for_canonicalization_test.py
+++ b/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.convert import control_flow
-from tensorflow.contrib.py2tf.convert import for_canonicalization
+from tensorflow.contrib.py2tf.converters import control_flow
+from tensorflow.contrib.py2tf.converters import converter_test_base
+from tensorflow.contrib.py2tf.converters import for_canonicalization
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.python.platform import test
@@ -32,12 +31,7 @@ class TestNamer(control_flow.SymbolNamer):
return name_root
-class ControlFlowTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn, namespace):
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- return node
+class ControlFlowTest(converter_test_base.TestCase):
def test_basic_for(self):
@@ -47,7 +41,7 @@ class ControlFlowTest(test.TestCase):
s += e
return s
- node = self._parse_and_analyze(test_fn, {})
+ node = self.parse_and_analyze(test_fn, {})
node = for_canonicalization.transform(node, TestNamer())
result = compiler.ast_to_object(node)
diff --git a/tensorflow/contrib/py2tf/convert/logical_expressions.py b/tensorflow/contrib/py2tf/converters/logical_expressions.py
index df980d41c9..df980d41c9 100644
--- a/tensorflow/contrib/py2tf/convert/logical_expressions.py
+++ b/tensorflow/contrib/py2tf/converters/logical_expressions.py
diff --git a/tensorflow/contrib/py2tf/convert/logical_expressions_test.py b/tensorflow/contrib/py2tf/converters/logical_expressions_test.py
index f07fa017b9..d711065099 100644
--- a/tensorflow/contrib/py2tf/convert/logical_expressions_test.py
+++ b/tensorflow/contrib/py2tf/converters/logical_expressions_test.py
@@ -18,21 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.convert import logical_expressions
+from tensorflow.contrib.py2tf.converters import converter_test_base
+from tensorflow.contrib.py2tf.converters import logical_expressions
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class GradientsFunctionTest(test.TestCase):
+class GradientsFunctionTest(converter_test_base.TestCase):
def test_equals(self):
def test_fn(a, b):
return a == b
- node = parser.parse_object(test_fn)
+ node = self.parse_and_analyze(test_fn, {})
node = logical_expressions.transform(node)
result = compiler.ast_to_object(node)
setattr(result, 'tf', math_ops)
@@ -46,7 +46,7 @@ class GradientsFunctionTest(test.TestCase):
def test_fn(a, b, c):
return (a or b) and (a or b or c)
- node = parser.parse_object(test_fn)
+ node = self.parse_and_analyze(test_fn, {})
node = logical_expressions.transform(node)
result = compiler.ast_to_object(node)
setattr(result, 'tf', math_ops)
diff --git a/tensorflow/contrib/py2tf/convert/print_functions.py b/tensorflow/contrib/py2tf/converters/print_functions.py
index 5da738c495..5da738c495 100644
--- a/tensorflow/contrib/py2tf/convert/print_functions.py
+++ b/tensorflow/contrib/py2tf/converters/print_functions.py
diff --git a/tensorflow/contrib/py2tf/convert/print_functions_test.py b/tensorflow/contrib/py2tf/converters/print_functions_test.py
index 65e592b66e..475196ce10 100644
--- a/tensorflow/contrib/py2tf/convert/print_functions_test.py
+++ b/tensorflow/contrib/py2tf/converters/print_functions_test.py
@@ -20,30 +20,20 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.py2tf.convert import print_functions
+from tensorflow.contrib.py2tf.converters import converter_test_base
+from tensorflow.contrib.py2tf.converters import print_functions
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
-from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
-from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.platform import test
-class PrintFunctionsTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn, namespace):
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, {})
- return node
+class PrintFunctionsTest(converter_test_base.TestCase):
def test_transform(self):
def test_fn(a):
print(a)
- node = self._parse_and_analyze(test_fn, {'print': print})
+ node = self.parse_and_analyze(test_fn, {'print': print})
node = print_functions.transform(node)
result = compiler.ast_to_object(node)
diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards.py b/tensorflow/contrib/py2tf/converters/side_effect_guards.py
index 1f25303fba..4df723989d 100644
--- a/tensorflow/contrib/py2tf/convert/side_effect_guards.py
+++ b/tensorflow/contrib/py2tf/converters/side_effect_guards.py
@@ -34,6 +34,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from contextlib import contextmanager
+
import gast
from tensorflow.contrib.py2tf.pyct import anno
@@ -94,12 +96,10 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
return node
def _gate_symbols(self, guard_statement, guarded_args):
-
- def template(args): # pylint:disable=unused-argument
- (args,) = (tf.identity(a) for a in (args,)) # pylint:disable=undefined-variable
-
- guards = templates.replace(
- template, args=tuple(gast.Name(a, None, None) for a in guarded_args))
+ template = """
+ (args,) = (tf.identity(a) for a in (args,))
+ """
+ guards = templates.replace(template, args=tuple(guarded_args))
guard_statement.body.extend(guards)
return guard_statement
@@ -110,29 +110,25 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
# opt.minimize(loss)
# or:
# tf.py_func(...)
-
args_scope = anno.getanno(node.value, 'args_scope')
temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced)
# TODO(mdan): Unsafe reference modification!
args_scope.mark_write(temp_name)
-
- def template(call, temp_result):
+ template = """
temp_result = call
if temp_result is not None:
if not isinstance(temp_result, (list, tuple)):
temp_result = (temp_result,)
- ctx = tf.control_dependencies(temp_result) # pylint:disable=undefined-variable
+ ctx = tf.control_dependencies(temp_result)
else:
- ctx = contextmanager(lambda: (yield))() # pylint:disable=undefined-variable
+ ctx = contextmanager(lambda: (yield))()
with ctx:
# TODO(mdan): Also insert ops to re-fetch if variables are involved.
pass # Will be removed below.
-
- # TODO(mdan): This is brittle. Reorganize this mechanism.
+ """
+ # TODO(mdan): This is brittle. Reorganize the mechanism.
statements = templates.replace(
- template,
- call=node.value,
- temp_result=gast.Name(temp_name, None, None))
+ template, call=node.value, temp_result=temp_name)
control_deps_guard = statements[-1]
control_deps_guard.body = []
diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards_test.py b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py
index d932840186..5c56973dc2 100644
--- a/tensorflow/contrib/py2tf/convert/side_effect_guards_test.py
+++ b/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py
@@ -18,12 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.py2tf.convert import side_effect_guards
+from tensorflow.contrib.py2tf.converters import converter_test_base
+from tensorflow.contrib.py2tf.converters import side_effect_guards
from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.contrib.py2tf.pyct.static_analysis import access
-from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
-from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
@@ -37,14 +34,7 @@ class TestNamer(side_effect_guards.SymbolNamer):
return name_root
-class SideEffectGuardsTest(test.TestCase):
-
- def _parse_and_analyze(self, test_fn, namespace):
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, namespace, {})
- node = type_info.resolve(node, {})
- return node
+class SideEffectGuardsTest(converter_test_base.TestCase):
def test_transform(self):
@@ -52,7 +42,7 @@ class SideEffectGuardsTest(test.TestCase):
state_ops.assign(a, a + 1)
return a
- node = self._parse_and_analyze(test_fn, {'state_ops': state_ops})
+ node = self.parse_and_analyze(test_fn, {'state_ops': state_ops})
node = side_effect_guards.transform(node, TestNamer())
result = compiler.ast_to_object(node)
setattr(result, 'state_ops', state_ops)
diff --git a/tensorflow/contrib/py2tf/naming.py b/tensorflow/contrib/py2tf/naming.py
index 61772ec07b..a90758962b 100644
--- a/tensorflow/contrib/py2tf/naming.py
+++ b/tensorflow/contrib/py2tf/naming.py
@@ -34,8 +34,10 @@ class Namer(object):
* side_effect_guards.SymbolNamer
"""
- def __init__(self, global_namespace, name_map=None):
+ def __init__(self, global_namespace, recursive, name_map, partial_types):
self.global_namespace = global_namespace
+ self.recursive = recursive
+ self.partial_types = partial_types
self.renamed_calls = {}
if name_map is not None:
@@ -54,6 +56,7 @@ class Namer(object):
while new_name in self.global_namespace:
n += 1
new_name = '%s_%d' % (new_name_root, n)
+
if live_object is not None:
self.renamed_calls[live_object] = new_name
self.generated_names.add(new_name)
@@ -67,7 +70,9 @@ class Namer(object):
if live_object is not None and live_object in self.renamed_calls:
return self.renamed_calls[live_object]
- if owner_type is None:
+ if not self.recursive:
+ new_name = original_name
+ elif owner_type is None or owner_type in self.partial_types:
# Top level functions: rename
new_name_root = 'tf__%s' % original_name
new_name = new_name_root
diff --git a/tensorflow/contrib/py2tf/naming_test.py b/tensorflow/contrib/py2tf/naming_test.py
index 9403d9ae1f..7bfc9b8733 100644
--- a/tensorflow/contrib/py2tf/naming_test.py
+++ b/tensorflow/contrib/py2tf/naming_test.py
@@ -28,7 +28,7 @@ class NamerTest(test.TestCase):
def bar():
pass
- namer = naming.Namer(set())
+ namer = naming.Namer({}, True, None, ())
self.assertEqual('tf__foo', namer.compiled_function_name('foo'))
self.assertEqual('tf__bar', namer.compiled_function_name('bar', bar))
self.assertEqual({bar: 'tf__bar'}, namer.renamed_calls)
@@ -38,7 +38,7 @@ class NamerTest(test.TestCase):
def foo():
pass
- namer = naming.Namer(set())
+ namer = naming.Namer({}, True, None, ())
self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo))
self.assertEqual('tf__foo', namer.compiled_function_name('foo', foo))
@@ -46,22 +46,22 @@ class NamerTest(test.TestCase):
def foo():
pass
- namer = naming.Namer(set(('tf__foo',)))
+ namer = naming.Namer({'tf__foo': 1}, True, None, ())
self.assertEqual('tf__foo_1', namer.compiled_function_name('foo', foo))
def test_new_symbol_tracks_names(self):
- namer = naming.Namer(set())
+ namer = naming.Namer({}, True, None, ())
self.assertEqual('temp', namer.new_symbol('temp', set()))
self.assertItemsEqual(('temp',), namer.generated_names)
def test_new_symbol_avoids_duplicates(self):
- namer = naming.Namer(set())
+ namer = naming.Namer({}, True, None, ())
self.assertEqual('temp', namer.new_symbol('temp', set()))
self.assertEqual('temp_1', namer.new_symbol('temp', set()))
self.assertItemsEqual(('temp', 'temp_1'), namer.generated_names)
def test_new_symbol_avoids_conflicts(self):
- namer = naming.Namer(set(('temp',)))
+ namer = naming.Namer({'temp': 1}, True, None, ())
# temp is reserved in the global namespace
self.assertEqual('temp_1', namer.new_symbol('temp', set()))
# temp_2 is reserved in the local namespace
diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD
index b60ed918f5..e0331dbc97 100644
--- a/tensorflow/contrib/py2tf/pyct/BUILD
+++ b/tensorflow/contrib/py2tf/pyct/BUILD
@@ -20,9 +20,11 @@ py_library(
"__init__.py",
"anno.py",
"compiler.py",
+ "context.py",
"parser.py",
"pretty_printer.py",
"templates.py",
+ "transformer.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/py2tf/pyct/context.py
new file mode 100644
index 0000000000..73f3613d09
--- /dev/null
+++ b/tensorflow/contrib/py2tf/pyct/context.py
@@ -0,0 +1,42 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Conversion context containers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class EntityContext(object):
+ """Contains information about an entity, like source code.
+
+ Attributes:
+ namer: Namer that matches the contract of all converters.
+ source_code: The entity's source code.
+ source_file: The entity's source file.
+ namespace: Dict[str->*], containing symbols visible to the entity
+ (excluding parameters).
+ arg_values: Dict[str->*], containing parameter values, if known.
+ arg_types: Dict[str->*], containing parameter types, if known.
+ """
+
+ def __init__(self, namer, source_code, source_file, namespace, arg_values,
+ arg_types):
+ self.namer = namer
+ self.source_code = source_code
+ self.source_file = source_file
+ self.namespace = namespace
+ self.arg_values = {} if arg_values is None else arg_values
+ self.arg_types = {} if arg_types is None else arg_types
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
index 3e54590326..0042aa90ed 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
@@ -24,6 +24,7 @@ from __future__ import print_function
import gast
from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -69,7 +70,7 @@ class Scope(object):
raise KeyError(name)
-class TypeInfoResolver(gast.NodeTransformer):
+class TypeInfoResolver(transformer.Base):
"""Annotates symbols with type information where possible.
Nodes currently annotated:
@@ -77,9 +78,9 @@ class TypeInfoResolver(gast.NodeTransformer):
* Attribute (helps resolve object methods)
"""
- def __init__(self, value_hints):
+ def __init__(self, context):
+ super(TypeInfoResolver, self).__init__(context)
self.scope = Scope(None)
- self.value_hints = value_hints
self.function_level = 0
def visit_FunctionDef(self, node):
@@ -120,13 +121,11 @@ class TypeInfoResolver(gast.NodeTransformer):
self.generic_visit(node)
if isinstance(node.ctx, gast.Param):
self.scope.setval(node.id, gast.Name(node.id, gast.Load(), None))
- # TODO(mdan): Member functions should not need type hints.
- # We could attemp to extract im_class from the live_val annotation.
- if self.function_level == 1 and node.id in self.value_hints:
+ if self.function_level == 1 and node.id in self.context.arg_types:
# Forge a node to hold the type information, so that method calls on
# it can resolve the type.
type_holder = gast.Name(node.id, gast.Load(), None)
- type_string, type_obj = self.value_hints[node.id]
+ type_string, type_obj = self.context.arg_types[node.id]
anno.setanno(type_holder, 'type', type_obj)
anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
self.scope.setval(node.id, type_holder)
@@ -206,6 +205,5 @@ class TypeInfoResolver(gast.NodeTransformer):
return node
-def resolve(node, value_hints):
- assert value_hints is not None
- return TypeInfoResolver(value_hints).visit(node)
+def resolve(node, context):
+ return TypeInfoResolver(context).visit(node)
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
index 8526f42413..a491f49ca3 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
@@ -19,7 +19,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
@@ -54,17 +56,27 @@ class ScopeTest(test.TestCase):
class TypeInfoResolverTest(test.TestCase):
+ def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
+ ctx = context.EntityContext(
+ namer=None,
+ source_code=None,
+ source_file=None,
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types)
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ node = live_values.resolve(node, namespace, {})
+ node = type_info.resolve(node, ctx)
+ return node
+
def test_constructor_detection(self):
def test_fn():
opt = training.GradientDescentOptimizer(0.1)
return opt
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'training': training}, {})
- node = type_info.resolve(node, {})
-
+ node = self._parse_and_analyze(test_fn, {'training': training})
call_node = node.body[0].body[0].value
self.assertEquals(training.GradientDescentOptimizer,
anno.getanno(call_node, 'type'))
@@ -77,11 +89,7 @@ class TypeInfoResolverTest(test.TestCase):
opt = training.GradientDescentOptimizer(0.1)
opt.minimize(0)
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'training': training}, {})
- node = type_info.resolve(node, {})
-
+ node = self._parse_and_analyze(test_fn, {'training': training})
attr_call_node = node.body[0].body[1].value.func
self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
anno.getanno(attr_call_node, 'type_fqn'))
@@ -92,11 +100,7 @@ class TypeInfoResolverTest(test.TestCase):
with session.Session() as sess:
sess.run(x)
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'session': session}, {})
- node = type_info.resolve(node, {})
-
+ node = self._parse_and_analyze(test_fn, {'session': session})
constructor_call = node.body[0].body[0].items[0].context_expr
self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
self.assertEquals((session.__name__, 'Session'),
@@ -115,33 +119,25 @@ class TypeInfoResolverTest(test.TestCase):
opt = training.GradientDescentOptimizer(0.01)
opt.minimize(0)
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'training': training}, {})
- with self.assertRaises(ValueError):
- node = type_info.resolve(node, {})
+ with self.assertRaises(transformer.PyFlowParseError):
+ self._parse_and_analyze(test_fn, {'training': training})
def test_parameter_class_members(self):
def test_fn(opt):
opt.minimize(0)
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'training': training}, {})
- with self.assertRaises(ValueError):
- node = type_info.resolve(node, {})
+ with self.assertRaises(transformer.PyFlowParseError):
+ self._parse_and_analyze(test_fn, {'training': training})
def test_parameter_class_members_with_value_hints(self):
def test_fn(opt):
opt.minimize(0)
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'training': training}, {})
- node = type_info.resolve(
- node, {
+ node = self._parse_and_analyze(
+ test_fn, {'training': training},
+ arg_types={
'opt': (('%s.GradientDescentOptimizer' % training.__name__),
training.GradientDescentOptimizer(0.1))
})
@@ -160,11 +156,8 @@ class TypeInfoResolverTest(test.TestCase):
foo = bar
foo()
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'bar': bar}, {})
- with self.assertRaises(ValueError):
- node = type_info.resolve(node, {})
+ with self.assertRaises(transformer.PyFlowParseError):
+ self._parse_and_analyze(test_fn, {'bar': bar})
def test_nested_members(self):
@@ -172,11 +165,8 @@ class TypeInfoResolverTest(test.TestCase):
foo = training.GradientDescentOptimizer(0.1)
foo.bar.baz()
- node = parser.parse_object(test_fn)
- node = access.resolve(node)
- node = live_values.resolve(node, {'training': training}, {})
- with self.assertRaises(ValueError):
- node = type_info.resolve(node, {})
+ with self.assertRaises(transformer.PyFlowParseError):
+ self._parse_and_analyze(test_fn, {'training': training})
if __name__ == '__main__':
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py
index 4fadc793e6..77c5fbe02a 100644
--- a/tensorflow/contrib/py2tf/pyct/templates.py
+++ b/tensorflow/contrib/py2tf/pyct/templates.py
@@ -80,37 +80,46 @@ class ReplaceTransformer(gast.NodeTransformer):
return node
+def _strings_to_names(n):
+ if isinstance(n, str):
+ # Note: the node will receive the ctx value from the template, see
+ # ReplaceTransformer.visit_Name.
+ return gast.Name(id=n, ctx=None, annotation=None)
+ if isinstance(n, list):
+ return [_strings_to_names(e) for e in n]
+ if isinstance(n, tuple):
+ return tuple(_strings_to_names(e) for e in n)
+ return n
+
+
def replace(template, **replacements):
"""Replace placeholders in a Python template.
+ AST Name and Tuple nodes always receive the context that inferred from
+ the template. However, when replacing more complex nodes (that can potentially
+ contain Name children), then the caller is responsible for setting the
+ appropriate context.
+
Args:
- template: A function to be used as a template. Any placeholder is expected
- to also be a function argument.
+ template: A string representing Python code. Any symbol name can be used
+ that appears in the template code can be used as placeholder.
**replacements: A mapping from placeholder names to (lists of) AST nodes
- that these placeholders will be replaced by.
+ that these placeholders will be replaced by. String values are also
+ supported as a shorthand for AST Name nodes with the respective ID.
Returns:
- body: An AST node or list of AST nodes with the replacements made. If the
- template was a function, a list will be returned. If the template was a
- node, the same node will be returned. If the template was a string, an
- AST node will be returned (a `Module` node in the case of a multi-line
- string, an `Expr` node otherwise).
+ An AST node or list of AST nodes with the replacements made. If the
+ template was a function, a list will be returned. If the template was a
+ node, the same node will be returned. If the template was a string, an
+ AST node will be returned (a `Module` node in the case of a multi-line
+ string, an `Expr` node otherwise).
Raises:
- ValueError: If a function is used as a template and an incorrect set of
- replacements was passed.
+ ValueError: if the arguments are incorrect.
"""
- tree = parser.parse_object(template).body[0]
- placeholders = set(arg.id for arg in tree.args.args)
- tree.args.args = []
- if tree.args.vararg:
- placeholders.add(tree.args.vararg)
- tree.args.vararg = None
- if set(replacements.keys()) != placeholders:
- raise ValueError(
- 'too many or few replacements. replacements: %s; placeholders: %s' %
- (replacements.keys(), placeholders))
-
- # Perform the replacement, stripping the function into which the template was
- # wrapped.
+ if not isinstance(template, str):
+ raise ValueError('Expected string template, got %s' % type(template))
+ tree = parser.parse_str(template)
+ for k in replacements:
+ replacements[k] = _strings_to_names(replacements[k])
return ReplaceTransformer(replacements).visit(tree).body
diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py
index 2ad8b9317b..1143131283 100644
--- a/tensorflow/contrib/py2tf/pyct/templates_test.py
+++ b/tensorflow/contrib/py2tf/pyct/templates_test.py
@@ -28,46 +28,42 @@ from tensorflow.python.platform import test
class TemplatesTest(test.TestCase):
def test_replace_variable(self):
- def template(a): # pylint:disable=unused-argument
- def test_fn(a): # pylint:disable=unused-variable
+ template = """
+ def test_fn(a):
a += 1
a = 2 * a + 1
- return b # pylint:disable=undefined-variable
+ return b
+ """
- node = templates.replace(
- template, a=gast.Name('b', gast.Load(), None))[0]
+ node = templates.replace(template, a='b')[0]
result = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_replace_function_name(self):
- def template(fname): # pylint:disable=unused-argument
- def fname(a): # pylint:disable=function-redefined
+ template = """
+ def fname(a):
a += 1
a = 2 * a + 1
return a
+ """
- node = templates.replace(
- template, fname=gast.Name('test_fn', gast.Load(), None))[0]
+ node = templates.replace(template, fname='test_fn')[0]
result = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_code_block(self):
- def template(block): # pylint:disable=unused-argument
- def test_fn(a): # pylint:disable=unused-variable
- block # pylint:disable=pointless-statement
+ template = """
+ def test_fn(a):
+ block
return a
+ """
node = templates.replace(
template,
block=[
- gast.Assign(
- [
- gast.Name('a', gast.Store(), None)
- ],
- gast.BinOp(
- gast.Name('a', gast.Load(), None),
- gast.Add(),
- gast.Num(1))),
+ gast.Assign([
+ gast.Name('a', None, None)
+ ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
] * 2)[0]
result = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn(1))
diff --git a/tensorflow/contrib/py2tf/pyct/transformer.py b/tensorflow/contrib/py2tf/pyct/transformer.py
new file mode 100644
index 0000000000..d5aa23eaeb
--- /dev/null
+++ b/tensorflow/contrib/py2tf/pyct/transformer.py
@@ -0,0 +1,58 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A node transformer that includes utilities for SCT."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import pretty_printer
+
+
+class PyFlowParseError(SyntaxError):
+ pass
+
+
+class Base(gast.NodeTransformer):
+ """Base class for specialized transformers."""
+
+ def __init__(self, context):
+ """Initialize the transformer. Subclasses should call this.
+
+ Args:
+ context: An EntityContext.
+ """
+ self._lineno = 0
+ self._col_offset = 0
+ self.context = context
+
+ def visit(self, node):
+ try:
+ source_code = self.context.source_code
+ source_file = self.context.source_file
+ if source_code and hasattr(node, 'lineno'):
+ self._lineno = node.lineno
+ self._col_offset = node.col_offset
+ return super(Base, self).visit(node)
+ except ValueError as e:
+ msg = '%s\nOccurred at node:\n%s' % (str(e), pretty_printer.fmt(node))
+ if source_code:
+ line = self._source.splitlines()[self._lineno - 1]
+ else:
+ line = '<no source available>'
+ raise PyFlowParseError(
+ msg, (source_file, self._lineno, self._col_offset + 1, line))
diff --git a/tensorflow/contrib/quantize/__init__.py b/tensorflow/contrib/quantize/__init__.py
index 5d4e4575c9..933200e607 100644
--- a/tensorflow/contrib/quantize/__init__.py
+++ b/tensorflow/contrib/quantize/__init__.py
@@ -27,6 +27,8 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
"create_eval_graph",
"create_training_graph",
+ "experimental_create_eval_graph",
+ "experimental_create_training_graph",
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index d647bb94e8..bbd9743d80 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -128,3 +128,67 @@ def create_eval_graph(input_graph, elements=None, device_name_or_function=None):
is_training=False,
elements=elements,
device_name_or_function=device_name_or_function)
+
+
+def experimental_create_training_graph(input_graph,
+ elements=None,
+ device_name_or_function=None):
+ """Returns a transformed training input_graph for simulated quantization.
+
+ This function has additional experimental options not (yet) available to
+ create_training_graph. The resulting behavior may be undefined.
+ The forward pass has fake quantization ops inserted to simulate the error
+ introduced by quantization.
+
+ Args:
+ input_graph: The tf.Graph to be transformed.
+ elements: (Optional) List of Tensors and Operations in input_graph whose
+ corresponding elements in the new graph will be returned.
+ device_name_or_function: (Optional) The device name or function to use.
+
+ Returns:
+ g is new tf.Graph that is rewritten for simulated quantization.
+ l is a list of Tensors/Operations in g corresponding to the provided input
+ elements, if elements is not None.
+
+ Raises:
+ ValueError: If elements contains an element that isn't a tf.Tensor or
+ tf.Operation.
+ """
+ return _create_graph(
+ input_graph=input_graph,
+ is_training=True,
+ elements=elements,
+ device_name_or_function=device_name_or_function)
+
+
+def experimental_create_eval_graph(input_graph,
+ elements=None,
+ device_name_or_function=None):
+ """Returns a transformed eval input_graph for simulated quantization.
+
+ This function has additional experimental options not (yet) available to
+ create_eval_graph. The resulting behavior may be undefined.
+ The forward pass has fake quantization ops inserted to simulate the error
+ introduced by quantization.
+
+ Args:
+ input_graph: The tf.Graph to be transformed.
+ elements: (Optional) List of Tensors and Operations in input_graph whose
+ corresponding elements in the new graph will be returned.
+ device_name_or_function: (Optional) The device name or function to use.
+
+ Returns:
+ g is new tf.Graph that is rewritten for simulated quantization.
+ l is a list of Tensors/Operations in g corresponding to the provided input
+ elements, if elements is not None.
+
+ Raises:
+ ValueError: If elements contains an element that isn't a tf.Tensor or
+ tf.Operation.
+ """
+ return _create_graph(
+ input_graph=input_graph,
+ is_training=False,
+ elements=elements,
+ device_name_or_function=device_name_or_function)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index 3407ace391..514862a0ab 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -31,28 +31,30 @@ from tensorflow.python.platform import googletest
class QuantizeGraphTest(test_util.TensorFlowTestCase):
-
# We have a lot of other tests that test the details of the rewrite, here we
# just the specific features of the quantize_graph API.
- def testReturnedElementsTraining(self):
- self._TestReturnElements(True)
- def testReturnedElementsEval(self):
- self._TestReturnElements(False)
+ def _RunTestOverParameters(self, test_fn):
+ rewrite_fns = [
+ quantize_graph.create_training_graph,
+ quantize_graph.create_eval_graph,
+ quantize_graph.experimental_create_training_graph,
+ quantize_graph.experimental_create_eval_graph,
+ ]
+ for fn in rewrite_fns:
+ test_fn(fn)
+
+ def testReturnedElements(self):
+ self._RunTestOverParameters(self._TestReturnElements)
- def _TestReturnElements(self, is_training):
+ def _TestReturnElements(self, fn):
graph = ops.Graph()
with graph.as_default():
a = constant_op.constant(1.0)
b = variables.Variable(2.0)
c = a + b
elements = [a, b, c.op]
- if is_training:
- q_graph, returned_elements = quantize_graph.create_training_graph(
- graph, elements=elements)
- else:
- q_graph, returned_elements = quantize_graph.create_eval_graph(
- graph, elements=elements)
+ q_graph, returned_elements = fn(graph, elements=elements)
# Make sure q_graph is different from graph.
self.assertTrue(graph != q_graph)
# Check that the returned elements are part of the new graph.
@@ -62,35 +64,26 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
for element, returned_element in zip(elements, returned_elements):
self.assertEqual(element.name, returned_element.name)
- def testNoReturnElementsTraining(self):
- self._TestNoReturnElements(True)
+ def testNoReturnElements(self):
+ self._RunTestOverParameters(self._TestNoReturnElements)
- def testNoReturnElementsEval(self):
- self._TestNoReturnElements(False)
-
- def _TestNoReturnElements(self, is_training):
+ def _TestNoReturnElements(self, fn):
graph = ops.Graph()
with graph.as_default():
a = constant_op.constant(1.0)
b = variables.Variable(2.0)
_ = a + b
- if is_training:
- q_graph = quantize_graph.create_training_graph(graph)
- else:
- q_graph = quantize_graph.create_eval_graph(graph)
+ q_graph = fn(graph)
# Check that quantize_graph didn't return a tuple when elements isn't
# provided.
self.assertTrue(isinstance(q_graph, ops.Graph))
# Make sure q_graph is different from graph.
self.assertTrue(graph != q_graph)
- def testDeviceNameTraining(self):
- self._TestDeviceName(True)
-
- def testDeviceNameEval(self):
- self._TestDeviceName(False)
+ def testDeviceName(self):
+ self._RunTestOverParameters(self._TestDeviceName)
- def _TestDeviceName(self, is_training):
+ def _TestDeviceName(self, fn):
graph = ops.Graph()
with graph.as_default():
batch_size, height, width, depth = 5, 128, 128, 3
@@ -106,12 +99,7 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
_ = nn_ops.relu6(conv)
device_name = '/job:oink/task:0/device:CPU:0'
- if is_training:
- q_graph = quantize_graph.create_training_graph(
- graph, device_name_or_function=device_name)
- else:
- q_graph = quantize_graph.create_eval_graph(
- graph, device_name_or_function=device_name)
+ q_graph = fn(graph, device_name_or_function=device_name)
orig_variable_names = set(
[v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
diff --git a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
index b2360fec6c..0388079f20 100644
--- a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
+++ b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py
@@ -61,7 +61,7 @@ def _compute_output_resolution(input_spatial_resolution, kernel_size, stride,
stride: Stride (int).
total_padding: Total padding to be applied (int).
Returns:
- output_resolution: Ouput dimension (int) or None.
+ output_resolution: Output dimension (int) or None.
"""
if (input_spatial_resolution is None) or (kernel_size is None) or (
stride is None) or (total_padding is None):
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
index fc3a2da9b3..9bb1724a2c 100644
--- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -81,4 +81,4 @@ CALL_ALL_REDUCEOPS(ReduceSliceFunctorReduceop)
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
index 8e6870fadd..501cddb8c8 100644
--- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
@@ -34,9 +34,9 @@ namespace functor {
__global__ void ReduceSliceDeviceKernel##reduceop( \
Cuda3DLaunchConfig config, Index indices_width, Index bound, \
const T begin, const Index *indices, const T *input, T *out) { \
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \
Index outidx = x * config.virtual_thread_count.y * \
config.virtual_thread_count.z + \
y * config.virtual_thread_count.z + z; \
@@ -68,8 +68,9 @@ namespace functor {
if (sizex * sizey * sizez == 0) { \
return; \
} \
- Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(sizex, sizey, sizez, d,\
- ReduceSliceDeviceKernel##reduceop<T, Index>, 0, 0); \
+ Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
+ sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
+ 0, 0); \
\
ReduceSliceDeviceKernel##reduceop<T, Index> \
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.h b/tensorflow/contrib/resampler/kernels/resampler_ops.h
index 8258ecaf5d..85d3676efa 100644
--- a/tensorflow/contrib/resampler/kernels/resampler_ops.h
+++ b/tensorflow/contrib/resampler/kernels/resampler_ops.h
@@ -13,8 +13,8 @@
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
+#define TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
#if PLATFORM_WINDOWS
#define __restrict__ __restrict
@@ -64,5 +64,4 @@ struct ResamplerGrad2DFunctor{
} // namespace functor
} // namespace tensorflow
-
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
+#endif // TENSORFLOW_CONTRIB_RESAMPLER_KERNELS_RESAMPLER_OPS_H_
diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h
index e33eceadff..a52c934233 100644
--- a/tensorflow/contrib/rnn/kernels/blas_gemm.h
+++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
+#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
+#define TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -74,4 +74,4 @@ struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
+#endif // TENSORFLOW_CONTRIB_RNN_KERNELS_BLAS_GEMM_H_
diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.h b/tensorflow/contrib/rnn/kernels/gru_ops.h
index 06a5665062..3e2cb39e64 100644
--- a/tensorflow/contrib/rnn/kernels/gru_ops.h
+++ b/tensorflow/contrib/rnn/kernels/gru_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
+#define TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
@@ -181,4 +181,4 @@ struct GRUBlockCellBprop : public GRUCell {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
+#endif // TENSORFLOW_CONTRIB_RNN_KERNELS_GRU_OPS_H_
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h
index 1906581b16..bc6b85f3f1 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops.h
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
+#define TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/rnn/kernels/blas_gemm.h"
@@ -291,4 +291,4 @@ struct BlockLSTMBprop : public LSTMBlockCell {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
+#endif // TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index b5d81b7caa..cafeb56ad8 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -663,6 +663,12 @@ class DropoutWrapperTest(test.TestCase):
self.assertEqual(res[1].h.shape, (batch_size, 3))
return res
+ def testWrappedCellProperty(self):
+ cell = rnn_cell_impl.BasicRNNCell(10)
+ wrapper = rnn_cell_impl.DropoutWrapper(cell)
+ # Github issue 15810
+ self.assertEqual(wrapper.wrapped_cell, cell)
+
def testDropoutWrapperKeepAllConstantInput(self):
keep = array_ops.ones([])
res = self._testDropoutWrapper(
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 73789206f3..8a3894ef9d 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -53,14 +53,12 @@ class RNNCellTest(test.TestCase):
batch_size = 3
input_size = 4
expected_output = np.array(
- [[0.121753, 0.121753],
- [0.103349, 0.103349],
- [0.100178, 0.100178]],
+ [[0.121753, 0.121753], [0.103349, 0.103349], [0.100178, 0.100178]],
dtype=np.float32)
expected_state = np.array(
- [[0.137523, 0.137523, 0.121753, 0.121753],
- [0.105450, 0.105450, 0.103349, 0.103349],
- [0.100742, 0.100742, 0.100178, 0.100178]],
+ [[0.137523, 0.137523, 0.121753, 0.121753], [
+ 0.105450, 0.105450, 0.103349, 0.103349
+ ], [0.100742, 0.100742, 0.100178, 0.100178]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@@ -69,14 +67,14 @@ class RNNCellTest(test.TestCase):
output, state = contrib_rnn_cell.CoupledInputForgetGateLSTMCell(
num_units=num_units, forget_bias=1.0, state_is_tuple=False)(x, m)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- x.name:
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]]),
- m.name:
- 0.1 * np.ones((batch_size, state_size))
- })
+ res = sess.run(
+ [output, state], {
+ x.name:
+ np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+ [3., 3., 3., 3.]]),
+ m.name:
+ 0.1 * np.ones((batch_size, state_size))
+ })
# This is a smoke test: Only making sure expected values didn't change.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], expected_output)
@@ -101,14 +99,14 @@ class RNNCellTest(test.TestCase):
frequency_skip=frequency_skip,
forget_bias=1.0)(x, m)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- x.name:
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]]),
- m.name:
- 0.1 * np.ones((batch_size, int(state_size * (num_shifts))))
- })
+ res = sess.run(
+ [output, state], {
+ x.name:
+ np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+ [3., 3., 3., 3.]]),
+ m.name:
+ 0.1 * np.ones((batch_size, int(state_size * (num_shifts))))
+ })
self.assertEqual(len(res), 2)
# The numbers in results were not calculated, this is mostly just a
# smoke test.
@@ -141,17 +139,14 @@ class RNNCellTest(test.TestCase):
state_is_tuple=True)
inputs = constant_op.constant(
np.array(
- [[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * num_shifts))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * num_shifts))
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -198,11 +193,10 @@ class RNNCellTest(test.TestCase):
dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * total_blocks))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * total_blocks))
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -230,20 +224,28 @@ class RNNCellTest(test.TestCase):
frequency_skip = 1
num_shifts = int((input_size - feature_size) / frequency_skip + 1)
expected_output = np.array(
- [[0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020,
- 0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699],
- [0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342,
- 0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171],
- [0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533,
- 0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250]],
+ [[
+ 0.416383, 0.416383, 0.403238, 0.403238, 0.524020, 0.524020,
+ 0.565425, 0.565425, 0.557865, 0.557865, 0.609699, 0.609699
+ ], [
+ 0.627331, 0.627331, 0.622393, 0.622393, 0.688342, 0.688342,
+ 0.708078, 0.708078, 0.694245, 0.694245, 0.715171, 0.715171
+ ], [
+ 0.711050, 0.711050, 0.709197, 0.709197, 0.736533, 0.736533,
+ 0.744264, 0.744264, 0.737390, 0.737390, 0.745250, 0.745250
+ ]],
dtype=np.float32)
expected_state = np.array(
- [[0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134,
- 0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865],
- [0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432,
- 0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245],
- [0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522,
- 0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390]],
+ [[
+ 0.625556, 0.625556, 0.416383, 0.416383, 0.759134, 0.759134,
+ 0.524020, 0.524020, 0.798795, 0.798795, 0.557865, 0.557865
+ ], [
+ 0.875488, 0.875488, 0.627331, 0.627331, 0.936432, 0.936432,
+ 0.688342, 0.688342, 0.941961, 0.941961, 0.694245, 0.694245
+ ], [
+ 0.957327, 0.957327, 0.711050, 0.711050, 0.979522, 0.979522,
+ 0.736533, 0.736533, 0.980245, 0.980245, 0.737390, 0.737390
+ ]],
dtype=np.float32)
for state_is_tuple in [False, True]:
with self.test_session() as sess:
@@ -259,18 +261,16 @@ class RNNCellTest(test.TestCase):
couple_input_forget_gates=True,
state_is_tuple=state_is_tuple)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
if state_is_tuple:
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * num_shifts))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * num_shifts))
else:
init_state = constant_op.constant(
0.1 * np.ones(
@@ -302,32 +302,40 @@ class RNNCellTest(test.TestCase):
frequency_skip = 1
num_shifts = int((input_size - feature_size) / frequency_skip + 1)
expected_output = np.array(
- [[0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
- 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
- 0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341,
- 0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218],
- [0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
- 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
- 0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517,
- 0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840],
- [0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
- 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
- 0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552,
- 0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731]],
+ [[
+ 0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
+ 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
+ 0.520789, 0.520789, 0.476968, 0.476968, 0.604341, 0.604341,
+ 0.760207, 0.760207, 0.635773, 0.635773, 0.850218, 0.850218
+ ], [
+ 0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
+ 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
+ 0.692621, 0.692621, 0.652363, 0.652363, 0.737517, 0.737517,
+ 0.899558, 0.899558, 0.745984, 0.745984, 0.946840, 0.946840
+ ], [
+ 0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
+ 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
+ 0.759940, 0.759940, 0.720652, 0.720652, 0.778552, 0.778552,
+ 0.941606, 0.941606, 0.781035, 0.781035, 0.977731, 0.977731
+ ]],
dtype=np.float32)
expected_state = np.array(
- [[0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
- 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
- 0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836,
- 0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773],
- [0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
- 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
- 0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288,
- 0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984],
- [1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
- 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
- 1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101,
- 0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035]],
+ [[
+ 0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
+ 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
+ 0.785405, 0.785405, 0.520789, 0.520789, 0.890836, 0.890836,
+ 0.604341, 0.604341, 0.928512, 0.928512, 0.635773, 0.635773
+ ], [
+ 0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
+ 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
+ 0.993088, 0.993088, 0.692621, 0.692621, 1.040288, 1.040288,
+ 0.737517, 0.737517, 1.048773, 1.048773, 0.745984, 0.745984
+ ], [
+ 1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
+ 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
+ 1.062455, 1.062455, 0.759940, 0.759940, 1.080101, 1.080101,
+ 0.778552, 0.778552, 1.082402, 1.082402, 0.781035, 0.781035
+ ]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@@ -339,17 +347,16 @@ class RNNCellTest(test.TestCase):
forget_bias=1.0,
num_frequency_blocks=[num_shifts])
inputs = constant_op.constant(
- np.array([[1.0, 1.1, 1.2, 1.3],
- [2.0, 2.1, 2.2, 2.3],
- [3.0, 3.1, 3.2, 3.3]],
- dtype=np.float32),
+ np.array(
+ [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
+ [3.0, 3.1, 3.2, 3.3]],
+ dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * num_shifts * 2))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * num_shifts * 2))
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -375,32 +382,40 @@ class RNNCellTest(test.TestCase):
frequency_skip = 1
num_shifts = int((input_size - feature_size) / frequency_skip + 1)
expected_output = np.array(
- [[0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
- 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
- 0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654,
- 0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071],
- [0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
- 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
- 0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828,
- 0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958],
- [0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
- 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
- 0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345,
- 0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858]],
+ [[
+ 0.464130, 0.464130, 0.419165, 0.419165, 0.593283, 0.593283,
+ 0.738350, 0.738350, 0.661638, 0.661638, 0.866774, 0.866774,
+ 0.322645, 0.322645, 0.276068, 0.276068, 0.584654, 0.584654,
+ 0.690292, 0.690292, 0.640446, 0.640446, 0.840071, 0.840071
+ ], [
+ 0.669636, 0.669636, 0.628966, 0.628966, 0.736057, 0.736057,
+ 0.895927, 0.895927, 0.755559, 0.755559, 0.954359, 0.954359,
+ 0.493625, 0.493625, 0.449236, 0.449236, 0.730828, 0.730828,
+ 0.865996, 0.865996, 0.749429, 0.749429, 0.944958, 0.944958
+ ], [
+ 0.751109, 0.751109, 0.711716, 0.711716, 0.778357, 0.778357,
+ 0.940779, 0.940779, 0.784530, 0.784530, 0.980604, 0.980604,
+ 0.608587, 0.608587, 0.566683, 0.566683, 0.777345, 0.777345,
+ 0.925820, 0.925820, 0.782597, 0.782597, 0.976858, 0.976858
+ ]],
dtype=np.float32)
expected_state = np.array(
- [[0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
- 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
- 0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628,
- 0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446],
- [0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
- 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
- 0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488,
- 0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429],
- [1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
- 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
- 0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978,
- 0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597]],
+ [[
+ 0.710660, 0.710660, 0.464130, 0.464130, 0.877293, 0.877293,
+ 0.593283, 0.593283, 0.958505, 0.958505, 0.661638, 0.661638,
+ 0.516575, 0.516575, 0.322645, 0.322645, 0.866628, 0.866628,
+ 0.584654, 0.584654, 0.934002, 0.934002, 0.640446, 0.640446
+ ], [
+ 0.967579, 0.967579, 0.669636, 0.669636, 1.038811, 1.038811,
+ 0.736057, 0.736057, 1.058201, 1.058201, 0.755559, 0.755559,
+ 0.749836, 0.749836, 0.493625, 0.493625, 1.033488, 1.033488,
+ 0.730828, 0.730828, 1.052186, 1.052186, 0.749429, 0.749429
+ ], [
+ 1.053842, 1.053842, 0.751109, 0.751109, 1.079919, 1.079919,
+ 0.778357, 0.778357, 1.085620, 1.085620, 0.784530, 0.784530,
+ 0.895999, 0.895999, 0.608587, 0.608587, 1.078978, 1.078978,
+ 0.777345, 0.777345, 1.083843, 1.083843, 0.782597, 0.782597
+ ]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@@ -413,17 +428,16 @@ class RNNCellTest(test.TestCase):
num_frequency_blocks=[num_shifts],
backward_slice_offset=1)
inputs = constant_op.constant(
- np.array([[1.0, 1.1, 1.2, 1.3],
- [2.0, 2.1, 2.2, 2.3],
- [3.0, 3.1, 3.2, 3.3]],
- dtype=np.float32),
+ np.array(
+ [[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3],
+ [3.0, 3.1, 3.2, 3.3]],
+ dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- init_state = cell.state_tuple_type(
- *([state_value, state_value] * num_shifts * 2))
+ init_state = cell.state_tuple_type(*(
+ [state_value, state_value] * num_shifts * 2))
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state])
@@ -474,8 +488,8 @@ class RNNCellTest(test.TestCase):
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
with self.test_session() as sess:
- with variable_scope.variable_scope("state_is_tuple_" + str(
- state_is_tuple)):
+ with variable_scope.variable_scope(
+ "state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
cell = contrib_rnn_cell.AttentionCellWrapper(
@@ -525,16 +539,15 @@ class RNNCellTest(test.TestCase):
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
with self.test_session() as sess:
- with variable_scope.variable_scope("state_is_tuple_" + str(
- state_is_tuple)):
+ with variable_scope.variable_scope(
+ "state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
cell = contrib_rnn_cell.AttentionCellWrapper(
lstm_cell, attn_length, state_is_tuple=state_is_tuple)
if state_is_tuple:
zeros = constant_op.constant(
- 0.1 * np.ones(
- [batch_size, num_units], dtype=np.float32),
+ 0.1 * np.ones([batch_size, num_units], dtype=np.float32),
dtype=dtypes.float32)
attn_state_zeros = constant_op.constant(
0.1 * np.ones(
@@ -579,22 +592,25 @@ class RNNCellTest(test.TestCase):
[1.018088, 0.378983, -0.572179, 0.268591]],
dtype=np.float32)
expected_state = np.array(
- [[0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
- 0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
- 0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
- 0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
- 0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
- 0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
- 0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
- 0.51843399],
- [0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
- 0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
- 0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
- 0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
- 0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
- 0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
- 0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
- 0.70582712]],
+ [[
+ 0.74946702, 0.34681597, 0.26474735, 1.06485605, 0.38465962,
+ 0.11420801, 0.10272158, 0.30925757, 0.63899988, 0.7181077,
+ 0.47534478, 0.33715725, 0.58086717, 0.49446869, 0.7641536,
+ 0.12814975, 0.92231739, 0.89857256, 0.21889746, 0.38442063,
+ 0.53481543, 0.8876909, 0.45823169, 0.5905602, 0.78038228,
+ 0.56501579, 0.03971386, 0.09870267, 0.8074435, 0.66821432,
+ 0.99211812, 0.12295902, 1.14606023, 0.34370938, -0.79251152,
+ 0.51843399
+ ], [
+ 0.5179342, 0.48682183, -0.25426468, 0.96810579, 0.28809637,
+ 0.13607743, -0.11446252, 0.26792109, 0.78047138, 0.63460857,
+ 0.49122369, 0.52007174, 0.73000264, 0.66986895, 0.73576689,
+ 0.86301267, 0.87887371, 0.35185754, 0.93417215, 0.64732957,
+ 0.63173044, 0.66627824, 0.53644657, 0.20477486, 0.98458421,
+ 0.38277245, 0.03746676, 0.92510188, 0.57714164, 0.84932971,
+ 0.36127412, 0.12125921, 1.1362772, 0.34361625, -0.78150457,
+ 0.70582712
+ ]],
dtype=np.float32)
seed = 12345
random_seed.set_random_seed(seed)
@@ -602,7 +618,8 @@ class RNNCellTest(test.TestCase):
for state_is_tuple in [False, True]:
with session.Session() as sess:
with variable_scope.variable_scope(
- "state_is_tuple", reuse=state_is_tuple,
+ "state_is_tuple",
+ reuse=state_is_tuple,
initializer=init_ops.glorot_uniform_initializer()):
lstm_cell = rnn_cell.BasicLSTMCell(
num_units, state_is_tuple=state_is_tuple)
@@ -646,36 +663,31 @@ class RNNCellTest(test.TestCase):
def testNASCell(self):
num_units = 6
batch_size = 3
- expected_output = np.array([[0.576751, 0.576751, 0.576751, 0.576751,
- 0.576751, 0.576751],
- [0.618936, 0.618936, 0.618936, 0.618936,
- 0.618936, 0.618936],
- [0.627393, 0.627393, 0.627393, 0.627393,
- 0.627393, 0.627393]])
- expected_state = np.array([[0.71579772, 0.71579772, 0.71579772, 0.71579772,
- 0.71579772, 0.71579772, 0.57675087, 0.57675087,
- 0.57675087, 0.57675087, 0.57675087, 0.57675087],
- [0.78041625, 0.78041625, 0.78041625, 0.78041625,
- 0.78041625, 0.78041625, 0.6189357, 0.6189357,
- 0.61893570, 0.6189357, 0.6189357, 0.6189357],
- [0.79457647, 0.79457647, 0.79457647, 0.79457647,
- 0.79457653, 0.79457653, 0.62739348, 0.62739348,
- 0.62739348, 0.62739348, 0.62739348, 0.62739348]
- ])
+ expected_output = np.array(
+ [[0.576751, 0.576751, 0.576751, 0.576751, 0.576751, 0.576751],
+ [0.618936, 0.618936, 0.618936, 0.618936, 0.618936, 0.618936],
+ [0.627393, 0.627393, 0.627393, 0.627393, 0.627393, 0.627393]])
+ expected_state = np.array([[
+ 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772, 0.71579772,
+ 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087, 0.57675087
+ ], [
+ 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625, 0.78041625,
+ 0.6189357, 0.6189357, 0.61893570, 0.6189357, 0.6189357, 0.6189357
+ ], [
+ 0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
+ 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
+ ]])
with self.test_session() as sess:
with variable_scope.variable_scope(
- "nas_test",
- initializer=init_ops.constant_initializer(0.5)):
+ "nas_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
state_value = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
init_state = rnn_cell.LSTMStateTuple(state_value, state_value)
output, state = cell(inputs, init_state)
@@ -699,39 +711,34 @@ class RNNCellTest(test.TestCase):
num_units = 6
batch_size = 3
num_proj = 5
- expected_output = np.array([[1.697418, 1.697418, 1.697418, 1.697418,
- 1.697418],
- [1.840037, 1.840037, 1.840037, 1.840037,
- 1.840037],
- [1.873985, 1.873985, 1.873985, 1.873985,
- 1.873985]])
- expected_state = np.array([[0.69855207, 0.69855207, 0.69855207, 0.69855207,
- 0.69855207, 0.69855207, 1.69741797, 1.69741797,
- 1.69741797, 1.69741797, 1.69741797],
- [0.77073824, 0.77073824, 0.77073824, 0.77073824,
- 0.77073824, 0.77073824, 1.84003687, 1.84003687,
- 1.84003687, 1.84003687, 1.84003687],
- [0.78973997, 0.78973997, 0.78973997, 0.78973997,
- 0.78973997, 0.78973997, 1.87398517, 1.87398517,
- 1.87398517, 1.87398517, 1.87398517]])
+ expected_output = np.array(
+ [[1.697418, 1.697418, 1.697418, 1.697418,
+ 1.697418], [1.840037, 1.840037, 1.840037, 1.840037, 1.840037],
+ [1.873985, 1.873985, 1.873985, 1.873985, 1.873985]])
+ expected_state = np.array([[
+ 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207, 0.69855207,
+ 1.69741797, 1.69741797, 1.69741797, 1.69741797, 1.69741797
+ ], [
+ 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824, 0.77073824,
+ 1.84003687, 1.84003687, 1.84003687, 1.84003687, 1.84003687
+ ], [
+ 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
+ 1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
+ ]])
with self.test_session() as sess:
with variable_scope.variable_scope(
- "nas_proj_test",
- initializer=init_ops.constant_initializer(0.5)):
+ "nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
state_value_c = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
state_value_h = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_proj), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_proj), dtype=np.float32),
dtype=dtypes.float32)
init_state = rnn_cell.LSTMStateTuple(state_value_c, state_value_h)
output, state = cell(inputs, init_state)
@@ -755,24 +762,20 @@ class RNNCellTest(test.TestCase):
num_units = 2
batch_size = 3
expected_state_and_output = np.array(
- [[0.13752282, 0.13752282],
- [0.10545051, 0.10545051],
+ [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
with self.test_session() as sess:
with variable_scope.variable_scope(
- "ugrnn_cell_test",
- initializer=init_ops.constant_initializer(0.5)):
+ "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
init_state = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
@@ -786,13 +789,11 @@ class RNNCellTest(test.TestCase):
num_units = 2
batch_size = 3
expected_state = np.array(
- [[0.13752282, 0.13752282],
- [0.10545051, 0.10545051],
+ [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
expected_output = np.array(
- [[2.00431061, 2.00431061],
- [4.00060606, 4.00060606],
+ [[2.00431061, 2.00431061], [4.00060606, 4.00060606],
[6.00008249, 6.00008249]],
dtype=np.float32)
with self.test_session() as sess:
@@ -802,14 +803,12 @@ class RNNCellTest(test.TestCase):
cell = contrib_rnn_cell.IntersectionRNNCell(
num_units=num_units, num_in_proj=num_units)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
init_state = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
output, state = cell(inputs, init_state)
sess.run([variables.global_variables_initializer()])
@@ -824,19 +823,17 @@ class RNNCellTest(test.TestCase):
batch_size = 3
cell = contrib_rnn_cell.IntersectionRNNCell(num_units=num_units)
inputs = constant_op.constant(
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]],
- dtype=np.float32),
+ np.array(
+ [[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]],
+ dtype=np.float32),
dtype=dtypes.float32)
init_state = constant_op.constant(
- 0.1 * np.ones(
- (batch_size, num_units), dtype=np.float32),
+ 0.1 * np.ones((batch_size, num_units), dtype=np.float32),
dtype=dtypes.float32)
- with self.assertRaisesRegexp(
- ValueError, "Must have input size == output size for "
- "Intersection RNN. To fix, num_in_proj should "
- "be set to num_units at cell init."):
+ with self.assertRaisesRegexp(ValueError,
+ "Must have input size == output size for "
+ "Intersection RNN. To fix, num_in_proj should "
+ "be set to num_units at cell init."):
cell(inputs, init_state)
def testPhasedLSTMCell(self):
@@ -845,13 +842,11 @@ class RNNCellTest(test.TestCase):
batch_size = 3
input_size = 4
expected_state_c = np.array(
- [[6.450831e-04, 4.697885e-04],
- [9.862894e-05, 7.212213e-04],
+ [[6.450831e-04, 4.697885e-04], [9.862894e-05, 7.212213e-04],
[4.401947e-04, 9.143004e-04]],
dtype=np.float32)
expected_state_h = np.array(
- [[4.621217e-04, 3.365449e-04],
- [7.438179e-05, 5.439147e-04],
+ [[4.621217e-04, 3.365449e-04], [7.438179e-05, 5.439147e-04],
[3.347936e-04, 6.953785e-04]],
dtype=np.float32)
with variable_scope.variable_scope(
@@ -864,14 +859,14 @@ class RNNCellTest(test.TestCase):
output, state = contrib_rnn_cell.PhasedLSTMCell(num_units=num_units)(
(t, x), state0)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- t.name:
- np.array([[1.], [2.], [3.]]),
- x.name:
- np.array([[1., 1., 1., 1.],
- [2., 2., 2., 2.],
- [3., 3., 3., 3.]]),
- })
+ res = sess.run(
+ [output, state], {
+ t.name:
+ np.array([[1.], [2.], [3.]]),
+ x.name:
+ np.array([[1., 1., 1., 1.], [2., 2., 2., 2.],
+ [3., 3., 3., 3.]]),
+ })
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
@@ -880,36 +875,32 @@ class RNNCellTest(test.TestCase):
def testConv1DLSTMCell(self):
with self.test_session() as sess:
- shape = [2,1]
+ shape = [2, 1]
filter_size = [3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
- [[[1.4375670191], [1.4375670191]],
- [[2.7542609292], [2.7542609292]]],
+ [[[1.4375670191], [1.4375670191]], [[2.7542609292], [2.7542609292]]],
dtype=np.float32)
expected_state_h = np.array(
- [[[0.6529865603], [0.6529865603]],
- [[0.8736877431], [0.8736877431]]],
+ [[[0.6529865603], [0.6529865603]], [[0.8736877431], [0.8736877431]]],
dtype=np.float32)
with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(1.0/2.0)):
+ "root", initializer=init_ops.constant_initializer(1.0 / 2.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, 1])
- cell = contrib_rnn_cell.Conv1DLSTMCell(input_shape=shape,
- kernel_shape=filter_size,
- output_channels=num_features)
+ cell = contrib_rnn_cell.Conv1DLSTMCell(
+ input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- hidden[0].name:
- np.array([[[1.],[1.]],
- [[2.],[2.]]]),
- x.name:
- np.array([[[1.],[1.]],
- [[2.],[2.]]]),
- })
+ res = sess.run(
+ [output, state], {
+ hidden[0].name: np.array([[[1.], [1.]], [[2.], [2.]]]),
+ x.name: np.array([[[1.], [1.]], [[2.], [2.]]]),
+ })
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
@@ -918,44 +909,40 @@ class RNNCellTest(test.TestCase):
def testConv2DLSTMCell(self):
with self.test_session() as sess:
- shape = [2,2,1]
- filter_size = [3,3]
+ shape = [2, 2, 1]
+ filter_size = [3, 3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
- [[[[1.4375670191], [1.4375670191]],
- [[1.4375670191], [1.4375670191]]],
- [[[2.7542609292], [2.7542609292]],
- [[2.7542609292], [2.7542609292]]]],
+ [[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]],
+ [[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]]
+ ]],
dtype=np.float32)
expected_state_h = np.array(
- [[[[0.6529865603], [0.6529865603]],
- [[0.6529865603], [0.6529865603]]],
- [[[0.8736877431], [0.8736877431]],
- [[0.8736877431], [0.8736877431]]]],
+ [[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]],
+ [[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]]
+ ]],
dtype=np.float32)
with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(1.0/4.0)):
+ "root", initializer=init_ops.constant_initializer(1.0 / 4.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, None, 1])
- cell = contrib_rnn_cell.Conv2DLSTMCell(input_shape=shape,
- kernel_shape=filter_size,
- output_channels=num_features)
+ cell = contrib_rnn_cell.Conv2DLSTMCell(
+ input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- hidden[0].name:
- np.array([[[[1.],[1.]],
- [[1.],[1.]]],
- [[[2.],[2.]],
- [[2.],[2.]]]]),
- x.name:
- np.array([[[[1.],[1.]],
- [[1.],[1.]]],
- [[[2.],[2.]],
- [[2.],[2.]]]]),
- })
+ res = sess.run(
+ [output, state], {
+ hidden[0].name:
+ np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]],
+ [[2.], [2.]]]]),
+ x.name:
+ np.array([[[[1.], [1.]], [[1.], [1.]]], [[[2.], [2.]],
+ [[2.], [2.]]]]),
+ })
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
@@ -964,36 +951,33 @@ class RNNCellTest(test.TestCase):
def testConv3DLSTMCell(self):
with self.test_session() as sess:
- shape = [2,2,2,1]
- filter_size = [3,3,3]
+ shape = [2, 2, 2, 1]
+ filter_size = [3, 3, 3]
num_features = 1
batch_size = 2
expected_state_c = np.array(
- [[[[[1.4375670191], [1.4375670191]],
- [[1.4375670191], [1.4375670191]]],
- [[[1.4375670191], [1.4375670191]],
- [[1.4375670191], [1.4375670191]]]],
- [[[[2.7542609292], [2.7542609292]],
- [[2.7542609292], [2.7542609292]]],
- [[[2.7542609292], [2.7542609292]],
- [[2.7542609292], [2.7542609292]]]]],
+ [[[[[1.4375670191], [1.4375670191]], [[1.4375670191], [1.4375670191]]
+ ], [[[1.4375670191], [1.4375670191]], [[1.4375670191],
+ [1.4375670191]]]],
+ [[[[2.7542609292], [2.7542609292]], [[2.7542609292], [2.7542609292]]
+ ], [[[2.7542609292], [2.7542609292]], [[2.7542609292],
+ [2.7542609292]]]]],
dtype=np.float32)
expected_state_h = np.array(
- [[[[[0.6529865603], [0.6529865603]],
- [[0.6529865603], [0.6529865603]]],
- [[[0.6529865603], [0.6529865603]],
- [[0.6529865603], [0.6529865603]]]],
- [[[[0.8736877431], [0.8736877431]],
- [[0.8736877431], [0.8736877431]]],
- [[[0.8736877431], [0.8736877431]],
- [[0.8736877431], [0.8736877431]]]]],
+ [[[[[0.6529865603], [0.6529865603]], [[0.6529865603], [0.6529865603]]
+ ], [[[0.6529865603], [0.6529865603]], [[0.6529865603],
+ [0.6529865603]]]],
+ [[[[0.8736877431], [0.8736877431]], [[0.8736877431], [0.8736877431]]
+ ], [[[0.8736877431], [0.8736877431]], [[0.8736877431],
+ [0.8736877431]]]]],
dtype=np.float32)
with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(1.0/8.0)):
+ "root", initializer=init_ops.constant_initializer(1.0 / 8.0)):
x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1])
- cell = contrib_rnn_cell.Conv3DLSTMCell(input_shape=shape,
- kernel_shape=filter_size,
- output_channels=num_features)
+ cell = contrib_rnn_cell.Conv3DLSTMCell(
+ input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
output, state = cell(x, hidden)
@@ -1056,8 +1040,8 @@ class RNNCellTest(test.TestCase):
num_units=num_units, number_of_groups=number_of_groups)
cell = rnn_cell.LSTMCell(num_units=num_units)
self.assertTrue(isinstance(gcell.state_size, tuple))
- zero_state = gcell.zero_state(batch_size=batch_size,
- dtype=dtypes.float32)
+ zero_state = gcell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
gh, gs = gcell(x, zero_state)
h, g = cell(x, zero_state)
@@ -1080,16 +1064,16 @@ class RNNCellTest(test.TestCase):
glstm_input = array_ops.ones([batch_size, num_units])
gcell = contrib_rnn_cell.GLSTMCell(
num_units=num_units, number_of_groups=number_of_groups)
- gcell_zero_state = gcell.zero_state(batch_size=batch_size,
- dtype=dtypes.float32)
+ gcell_zero_state = gcell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
gh, gs = gcell(glstm_input, gcell_zero_state)
# input for LSTM cell simulating single G-LSTM group
lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
# note division by number_of_groups. This cell one simulates G-LSTM group
cell = rnn_cell.LSTMCell(num_units=int(num_units / number_of_groups))
- cell_zero_state = cell.zero_state(batch_size=batch_size,
- dtype=dtypes.float32)
+ cell_zero_state = cell.zero_state(
+ batch_size=batch_size, dtype=dtypes.float32)
h, g = cell(lstm_input, cell_zero_state)
sess.run([variables.global_variables_initializer()])
@@ -1099,6 +1083,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
h_res, 1e-5)
+
class LayerNormBasicLSTMCellTest(test.TestCase):
# NOTE: all the values in the current test case have been calculated.
@@ -1119,13 +1104,14 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- c1.name: 0.1 * np.asarray([[4, 5]]),
- h1.name: 0.1 * np.asarray([[6, 7]]),
- })
+ res = sess.run(
+ [g, out_m], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ c1.name: 0.1 * np.asarray([[4, 5]]),
+ h1.name: 0.1 * np.asarray([[6, 7]]),
+ })
expected_h = np.array([[-0.38079708, 0.38079708]])
expected_state0_c = np.array([[-1.0, 1.0]])
@@ -1155,11 +1141,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2)
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.array([[1., 1., 1.]]),
- c.name: 0.1 * np.asarray([[0, 1]]),
- h.name: 0.1 * np.asarray([[2, 3]]),
- })
+ res = sess.run(
+ [g, out_m], {
+ x.name: np.array([[1., 1., 1.]]),
+ c.name: 0.1 * np.asarray([[0, 1]]),
+ h.name: 0.1 * np.asarray([[2, 3]]),
+ })
expected_h = np.array([[-0.38079708, 0.38079708]])
expected_c = np.array([[-1.0, 1.0]])
@@ -1168,7 +1155,6 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[1].c, expected_c, 1e-5)
self.assertAllClose(res[1].h, expected_h, 1e-5)
-
def testBasicLSTMCellWithoutNorm(self):
"""Tests that BasicLSTMCell with layer_norm=False."""
with self.test_session() as sess:
@@ -1186,19 +1172,20 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
cell = rnn_cell.MultiRNNCell([single_cell() for _ in range(2)])
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- c1.name: 0.1 * np.asarray([[4, 5]]),
- h1.name: 0.1 * np.asarray([[6, 7]]),
- })
+ res = sess.run(
+ [g, out_m], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ c1.name: 0.1 * np.asarray([[4, 5]]),
+ h1.name: 0.1 * np.asarray([[6, 7]]),
+ })
- expected_h = np.array([[ 0.70230919, 0.72581059]])
- expected_state0_c = np.array([[ 0.8020075, 0.89599884]])
- expected_state0_h = np.array([[ 0.56668288, 0.60858738]])
- expected_state1_c = np.array([[ 1.17500675, 1.26892781]])
- expected_state1_h = np.array([[ 0.70230919, 0.72581059]])
+ expected_h = np.array([[0.70230919, 0.72581059]])
+ expected_state0_c = np.array([[0.8020075, 0.89599884]])
+ expected_state0_h = np.array([[0.56668288, 0.60858738]])
+ expected_state1_c = np.array([[1.17500675, 1.26892781]])
+ expected_state1_h = np.array([[0.70230919, 0.72581059]])
actual_h = res[0]
actual_state0_c = res[1][0].c
@@ -1215,21 +1202,22 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
with variable_scope.variable_scope(
"other", initializer=init_ops.constant_initializer(0.5)) as vs:
x = array_ops.zeros(
- [1, 3]) # Test BasicLSTMCell with input_size != num_units.
+ [1, 3]) # Test BasicLSTMCell with input_size != num_units.
c = array_ops.zeros([1, 2])
h = array_ops.zeros([1, 2])
state = rnn_cell.LSTMStateTuple(c, h)
cell = contrib_rnn_cell.LayerNormBasicLSTMCell(2, layer_norm=False)
g, out_m = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.array([[1., 1., 1.]]),
- c.name: 0.1 * np.asarray([[0, 1]]),
- h.name: 0.1 * np.asarray([[2, 3]]),
- })
-
- expected_h = np.array([[ 0.64121795, 0.68166804]])
- expected_c = np.array([[ 0.88477188, 0.98103917]])
+ res = sess.run(
+ [g, out_m], {
+ x.name: np.array([[1., 1., 1.]]),
+ c.name: 0.1 * np.asarray([[0, 1]]),
+ h.name: 0.1 * np.asarray([[2, 3]]),
+ })
+
+ expected_h = np.array([[0.64121795, 0.68166804]])
+ expected_c = np.array([[0.88477188, 0.98103917]])
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], expected_h, 1e-5)
self.assertAllClose(res[1].c, expected_c, 1e-5)
@@ -1250,13 +1238,14 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
[contrib_rnn_cell.LayerNormBasicLSTMCell(2) for _ in range(2)])
h, (s0, s1) = cell(x, (state0, state1))
sess.run([variables.global_variables_initializer()])
- res = sess.run([h, s0, s1], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- c1.name: 0.1 * np.asarray([[4, 5]]),
- h1.name: 0.1 * np.asarray([[6, 7]]),
- })
+ res = sess.run(
+ [h, s0, s1], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ c1.name: 0.1 * np.asarray([[4, 5]]),
+ h1.name: 0.1 * np.asarray([[6, 7]]),
+ })
expected_h = np.array([[-0.38079708, 0.38079708]])
expected_h0 = np.array([[-0.38079708, 0.38079708]])
@@ -1344,11 +1333,12 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
g, s = cell(x, state)
sess.run([variables.global_variables_initializer()])
- res = sess.run([g, s], {
- x.name: np.ones([1, 5]),
- c.name: np.ones([1, 5]),
- h.name: np.ones([1, 5]),
- })
+ res = sess.run(
+ [g, s], {
+ x.name: np.ones([1, 5]),
+ c.name: np.ones([1, 5]),
+ h.name: np.ones([1, 5]),
+ })
# Since the returned tensors are of size [1,n]
# get the first component right now.
@@ -1374,35 +1364,35 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertIn(dropped_count, allowed_low)
-def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth,
- num_layers, max_time, compiled):
+def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth, num_layers,
+ max_time, compiled):
with variable_scope.variable_scope(
"root",
initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
inputs = variable_scope.get_variable(
- "inputs", initializer=random_ops.random_uniform(
+ "inputs",
+ initializer=random_ops.random_uniform(
(max_time, batch_size, input_depth), seed=1))
maybe_xla = lambda c: contrib_rnn_cell.CompiledWrapper(c) if compiled else c
cell = rnn_cell.MultiRNNCell(
[maybe_xla(rnn_cell.LSTMCell(num_units)) for _ in range(num_layers)])
- initial_state = cell.zero_state(
- batch_size=batch_size, dtype=dtypes.float32)
+ initial_state = cell.zero_state(batch_size=batch_size, dtype=dtypes.float32)
outputs, final_state = rnn.dynamic_rnn(
- cell=cell, inputs=inputs, initial_state=initial_state,
- time_major=True)
+ cell=cell, inputs=inputs, initial_state=initial_state, time_major=True)
flat_final_state = nest.flatten(final_state)
trainable_variables = variables.trainable_variables()
outputs_grad = gradients_impl.gradients(
- [outputs],
- trainable_variables + [inputs] + nest.flatten(initial_state))
+ [outputs], trainable_variables + [inputs] + nest.flatten(initial_state))
final_state_grad = gradients_impl.gradients(
flat_final_state,
trainable_variables + [inputs] + nest.flatten(initial_state))
- return {"outputs": outputs,
- "final_state": flat_final_state,
- "outputs_grad": outputs_grad,
- "final_state_grad": final_state_grad}
+ return {
+ "outputs": outputs,
+ "final_state": flat_final_state,
+ "outputs_grad": outputs_grad,
+ "final_state_grad": final_state_grad
+ }
class CompiledWrapperTest(test.TestCase):
@@ -1420,8 +1410,10 @@ class CompiledWrapperTest(test.TestCase):
random_seed.set_random_seed(1234)
with self.test_session(graph=ops.Graph()) as sess:
xla_ops = _create_multi_lstm_cell_ops(
- batch_size=batch_size, num_units=num_units,
- input_depth=input_depth, num_layers=num_layers,
+ batch_size=batch_size,
+ num_units=num_units,
+ input_depth=input_depth,
+ num_layers=num_layers,
max_time=max_time,
compiled=True)
sess.run([variables.global_variables_initializer()])
@@ -1430,8 +1422,10 @@ class CompiledWrapperTest(test.TestCase):
random_seed.set_random_seed(1234)
with self.test_session(graph=ops.Graph()) as sess:
non_xla_ops = _create_multi_lstm_cell_ops(
- batch_size=batch_size, num_units=num_units,
- input_depth=input_depth, num_layers=num_layers,
+ batch_size=batch_size,
+ num_units=num_units,
+ input_depth=input_depth,
+ num_layers=num_layers,
max_time=max_time,
compiled=False)
sess.run([variables.global_variables_initializer()])
@@ -1440,16 +1434,16 @@ class CompiledWrapperTest(test.TestCase):
self.assertAllClose(
non_xla_results["outputs"], xla_results["outputs"], atol=atol)
- for xla_value, non_xla_value in zip(
- xla_results["final_state"], non_xla_results["final_state"]):
+ for xla_value, non_xla_value in zip(xla_results["final_state"],
+ non_xla_results["final_state"]):
self.assertAllClose(xla_value, non_xla_value, atol=atol)
- for xla_g, non_xla_g in zip(
- xla_results["outputs_grad"], non_xla_results["outputs_grad"]):
+ for xla_g, non_xla_g in zip(xla_results["outputs_grad"],
+ non_xla_results["outputs_grad"]):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
- for xla_g, non_xla_g in zip(
- xla_results["final_state_grad"], non_xla_results["final_state_grad"]):
+ for xla_g, non_xla_g in zip(xla_results["final_state_grad"],
+ non_xla_results["final_state_grad"]):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
def testMultiRNNCellWithStateTuple(self):
@@ -1463,19 +1457,20 @@ class CompiledWrapperTest(test.TestCase):
# Test incorrectness of state
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
rnn_cell.MultiRNNCell(
- [rnn_cell.GRUCell(2)
- for _ in range(2)], state_is_tuple=True)(x, m_bad)
+ [rnn_cell.GRUCell(2) for _ in range(2)],
+ state_is_tuple=True)(x, m_bad)
_, ml = rnn_cell.MultiRNNCell(
- [rnn_cell.GRUCell(2)
- for _ in range(2)], state_is_tuple=True)(x, m_good)
+ [rnn_cell.GRUCell(2) for _ in range(2)],
+ state_is_tuple=True)(x, m_good)
sess.run([variables.global_variables_initializer()])
- res = sess.run(ml, {
- x.name: np.array([[1., 1.]]),
- m_good[0].name: np.array([[0.1, 0.1]]),
- m_good[1].name: np.array([[0.1, 0.1]])
- })
+ res = sess.run(
+ ml, {
+ x.name: np.array([[1., 1.]]),
+ m_good[0].name: np.array([[0.1, 0.1]]),
+ m_good[1].name: np.array([[0.1, 0.1]])
+ })
# The numbers in results were not calculated, this is just a
# smoke test. However, these numbers should match those of
@@ -1490,24 +1485,20 @@ class BenchmarkLSTMCellXLA(test.Benchmark):
num_layers = 3
max_time = 50
print("benchmarkDynamicRNNWithMultiLSTMCell")
- print("\t" +
- "\t".join(["inter_th", "intra_th",
- "batch_size", "num_units", "input_depth", "device",
- "compiled", "wall_time"]))
+ print("\t" + "\t".join([
+ "inter_th", "intra_th", "batch_size", "num_units", "input_depth",
+ "device", "compiled", "wall_time"
+ ]))
warmup_run = True
- for (threads,
- device,
- num_units,
- batch_size,
- input_depth,
- compiled) in itertools.product(
- [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}],
- ["cpu", "gpu"],
- [32, 512],
- [1, 32, 256],
- [32, 512],
- [False, True]):
+ for (threads, device, num_units, batch_size, input_depth,
+ compiled) in itertools.product([{
+ "inter": 0,
+ "intra": 0
+ }, {
+ "inter": 1,
+ "intra": 4
+ }], ["cpu", "gpu"], [32, 512], [1, 32, 256], [32, 512], [False, True]):
if threads["inter"] != 0:
# We only care about testing inter/intra op limitations on
# CPU with small batch size, to mimic embedded devices.
@@ -1523,31 +1514,222 @@ class BenchmarkLSTMCellXLA(test.Benchmark):
with session.Session(config=config, graph=ops.Graph()) as sess:
with ops.device("/%s:0" % device):
ops_dict = _create_multi_lstm_cell_ops(
- batch_size=batch_size, num_units=num_units,
- input_depth=input_depth, num_layers=num_layers,
+ batch_size=batch_size,
+ num_units=num_units,
+ input_depth=input_depth,
+ num_layers=num_layers,
max_time=max_time,
compiled=compiled)
sess.run([variables.global_variables_initializer()])
all_ops = nest.flatten(ops_dict.values())
all_ops_group = control_flow_ops.group(*all_ops)
- name_suffix = (
- "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
- "_device_%s_xla_%s" % (
- threads["inter"], threads["intra"],
- batch_size, num_units, input_depth, device, compiled))
+ name_suffix = ("inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
+ "_device_%s_xla_%s" %
+ (threads["inter"], threads["intra"], batch_size,
+ num_units, input_depth, device, compiled))
if warmup_run:
self.run_op_benchmark(
sess, all_ops_group, min_iters=30, name="ignore_warmup")
warmup_run = False
benchmark_results = self.run_op_benchmark(
- sess, all_ops_group, min_iters=50,
+ sess,
+ all_ops_group,
+ min_iters=50,
name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
- print("\t" +
- "\t".join(["%s" % x for x in [
- threads["inter"], threads["intra"],
- batch_size, num_units, input_depth, device, compiled,
- benchmark_results["wall_time"]]]))
+ print("\t" + "\t".join([
+ "%s" % x
+ for x in [
+ threads["inter"], threads["intra"], batch_size, num_units,
+ input_depth, device, compiled, benchmark_results["wall_time"]
+ ]
+ ]))
+
+
+class WeightNormLSTMCellTest(test.TestCase):
+ """Compared cell output with pre-calculated values."""
+
+ def _cell_output(self, cell):
+ """Calculate cell output"""
+
+ with self.test_session() as sess:
+ init = init_ops.constant_initializer(0.5)
+ with variable_scope.variable_scope("root", initializer=init):
+ x = array_ops.zeros([1, 2])
+ c0 = array_ops.zeros([1, 2])
+ h0 = array_ops.zeros([1, 2])
+
+ state0 = rnn_cell.LSTMStateTuple(c0, h0)
+
+ xout, sout = cell()(x, state0)
+
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run(
+ [xout, sout], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ })
+
+ actual_state_c = res[1].c
+ actual_state_h = res[1].h
+
+ return actual_state_c, actual_state_h
+
+ def testBasicCell(self):
+ """Tests cell w/o peepholes and w/o normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(
+ 2, norm=False, use_peepholes=False)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.65937078, 0.74983585]])
+ expected_h = np.array([[0.44923624, 0.49362513]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+ def testNonbasicCell(self):
+ """Tests cell with peepholes and w/o normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(
+ 2, norm=False, use_peepholes=True)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.65937084, 0.7574988]])
+ expected_h = np.array([[0.4792085, 0.53470564]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+ def testBasicCellWithNorm(self):
+ """Tests cell w/o peepholes and with normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(
+ 2, norm=True, use_peepholes=False)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.50125383, 0.58805949]])
+ expected_h = np.array([[0.32770363, 0.37397948]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+ def testNonBasicCellWithNorm(self):
+ """Tests cell with peepholes and with normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(
+ 2, norm=True, use_peepholes=True)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.50125383, 0.59587258]])
+ expected_h = np.array([[0.35041603, 0.40873795]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+
+class WeightNormLSTMCellTest(test.TestCase):
+ """Compared cell output with pre-calculated values."""
+
+ def _cell_output(self, cell):
+ """Calculate cell output"""
+
+ with self.test_session() as sess:
+ init = init_ops.constant_initializer(0.5)
+ with variable_scope.variable_scope("root",
+ initializer=init):
+ x = array_ops.zeros([1, 2])
+ c0 = array_ops.zeros([1, 2])
+ h0 = array_ops.zeros([1, 2])
+
+ state0 = rnn_cell.LSTMStateTuple(c0, h0)
+
+ xout, sout = cell()(x, state0)
+
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([xout, sout], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ })
+
+ actual_state_c = res[1].c
+ actual_state_h = res[1].h
+
+ return actual_state_c, actual_state_h
+
+ def testBasicCell(self):
+ """Tests cell w/o peepholes and w/o normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=False,
+ use_peepholes=False)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.65937078, 0.74983585]])
+ expected_h = np.array([[0.44923624, 0.49362513]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+ def testNonbasicCell(self):
+ """Tests cell with peepholes and w/o normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=False,
+ use_peepholes=True)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.65937084, 0.7574988]])
+ expected_h = np.array([[0.4792085, 0.53470564]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+
+ def testBasicCellWithNorm(self):
+ """Tests cell w/o peepholes and with normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=True,
+ use_peepholes=False)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.50125383, 0.58805949]])
+ expected_h = np.array([[0.32770363, 0.37397948]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+ def testNonBasicCellWithNorm(self):
+ """Tests cell with peepholes and with normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=True,
+ use_peepholes=True)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.50125383, 0.59587258]])
+ expected_h = np.array([[0.35041603, 0.40873795]])
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index e4667828cd..8adf5dce6e 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Module for constructing RNN Cells."""
from __future__ import absolute_import
from __future__ import division
@@ -38,6 +37,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import nn_impl
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@@ -55,16 +55,15 @@ def _get_concat_variable(name, shape, dtype, num_shards):
return value
concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
- ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
- concat_variable)
+ ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
return concat_variable
def _get_sharded_variable(name, shape, dtype, num_shards):
"""Get a list of sharded variables with the given dtype."""
if num_shards > shape[0]:
- raise ValueError("Too many shards: shape=%s, num_shards=%d" %
- (shape, num_shards))
+ raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
+ num_shards))
unit_shard_size = int(math.floor(shape[0] / num_shards))
remaining_rows = shape[0] - unit_shard_size * num_shards
@@ -73,8 +72,9 @@ def _get_sharded_variable(name, shape, dtype, num_shards):
current_size = unit_shard_size
if i < remaining_rows:
current_size += 1
- shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
- dtype=dtype))
+ shards.append(
+ vs.get_variable(
+ name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
return shards
@@ -176,9 +176,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
"""
super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
- logging.warn(
- "%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
self._num_units = num_units
self._use_peepholes = use_peepholes
self._initializer = initializer
@@ -195,12 +194,14 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
self._norm_shift = norm_shift
if num_proj:
- self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
- if state_is_tuple else num_units + num_proj)
+ self._state_size = (
+ rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
+ if state_is_tuple else num_units + num_proj)
self._output_size = num_proj
else:
- self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)
- if state_is_tuple else 2 * num_units)
+ self._state_size = (
+ rnn_cell_impl.LSTMStateTuple(num_units, num_units)
+ if state_is_tuple else 2 * num_units)
self._output_size = num_units
@property
@@ -250,8 +251,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
concat_w = _get_concat_variable(
- "W", [input_size.value + num_proj, 3 * self._num_units],
- dtype, self._num_unit_shards)
+ "W", [input_size.value + num_proj, 3 * self._num_units], dtype,
+ self._num_unit_shards)
b = vs.get_variable(
"B",
@@ -298,9 +299,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
m = sigmoid(o) * self._activation(c)
if self._num_proj is not None:
- concat_w_proj = _get_concat_variable(
- "W_P", [self._num_units, self._num_proj],
- dtype, self._num_proj_shards)
+ concat_w_proj = _get_concat_variable("W_P",
+ [self._num_units, self._num_proj],
+ dtype, self._num_proj_shards)
m = math_ops.matmul(m, concat_w_proj)
if self._proj_clip is not None:
@@ -308,8 +309,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
# pylint: enable=invalid-unary-operand-type
- new_state = (rnn_cell_impl.LSTMStateTuple(c, m)
- if self._state_is_tuple else array_ops.concat([c, m], 1))
+ new_state = (
+ rnn_cell_impl.LSTMStateTuple(c, m)
+ if self._state_is_tuple else array_ops.concat([c, m], 1))
return m, new_state
@@ -325,10 +327,15 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
It uses peep-hole connections and optional cell clipping.
"""
- def __init__(self, num_units, use_peepholes=False,
- cell_clip=None, initializer=None,
- num_unit_shards=1, forget_bias=1.0,
- feature_size=None, frequency_skip=None,
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
+ cell_clip=None,
+ initializer=None,
+ num_unit_shards=1,
+ forget_bias=1.0,
+ feature_size=None,
+ frequency_skip=1,
reuse=None):
"""Initialize the parameters for an LSTM cell.
@@ -398,7 +405,7 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
actual_input_size = freq_inputs[0].get_shape().as_list()[1]
concat_w = _get_concat_variable(
- "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
+ "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
dtype, self._num_unit_shards)
b = vs.get_variable(
@@ -417,23 +424,23 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
"W_O_diag", shape=[self._num_units], dtype=dtype)
# initialize the first freq state to be zero
- m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
- self._num_units], dtype)
+ m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units],
+ dtype)
for fq in range(len(freq_inputs)):
- c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
+ c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
[-1, self._num_units])
- m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
+ m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
[-1, self._num_units])
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
- cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
- 1)
+ cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1)
if self._use_peepholes:
- c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
- sigmoid(i + w_i_diag * c_prev) * tanh(j))
+ c = (
+ sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+ sigmoid(i + w_i_diag * c_prev) * tanh(j))
else:
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
@@ -471,11 +478,11 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
input_size = input_feat.get_shape().with_rank(2)[-1].value
if input_size is None:
raise ValueError("Cannot infer input_size from static shape inference.")
- num_feats = int((input_size - self._feature_size) / (
- self._frequency_skip)) + 1
+ num_feats = int(
+ (input_size - self._feature_size) / (self._frequency_skip)) + 1
freq_inputs = []
for f in range(num_feats):
- cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip],
+ cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
[-1, self._feature_size])
freq_inputs.append(cur_input)
return freq_inputs
@@ -497,11 +504,16 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
The code uses optional peephole connections, shared_weights and cell clipping.
"""
- def __init__(self, num_units, use_peepholes=False,
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
share_time_frequency_weights=False,
- cell_clip=None, initializer=None,
- num_unit_shards=1, forget_bias=1.0,
- feature_size=None, frequency_skip=None,
+ cell_clip=None,
+ initializer=None,
+ num_unit_shards=1,
+ forget_bias=1.0,
+ feature_size=None,
+ frequency_skip=None,
num_frequency_blocks=None,
start_freqindex_list=None,
end_freqindex_list=None,
@@ -579,10 +591,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
for freq_index in range(self._num_frequency_blocks[block_index]):
name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
- self._state_tuple_type = collections.namedtuple(
- "GridLSTMStateTuple", state_names.strip(","))
- self._state_size = self._state_tuple_type(
- *([num_units, num_units] * self._total_blocks))
+ self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
+ state_names.strip(","))
+ self._state_size = self._state_tuple_type(*(
+ [num_units, num_units] * self._total_blocks))
else:
self._state_tuple_type = None
self._state_size = num_units * self._total_blocks * 2
@@ -625,7 +637,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
state_out_lst = []
for block in range(len(freq_inputs)):
m_out_lst_current, state_out_lst_current = self._compute(
- freq_inputs[block], block, state, batch_size,
+ freq_inputs[block],
+ block,
+ state,
+ batch_size,
state_is_tuple=self._state_is_tuple)
m_out_lst.extend(m_out_lst_current)
state_out_lst.extend(state_out_lst_current)
@@ -636,7 +651,11 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
m_out = array_ops.concat(m_out_lst, 1)
return m_out, state_out
- def _compute(self, freq_inputs, block, state, batch_size,
+ def _compute(self,
+ freq_inputs,
+ block,
+ state,
+ batch_size,
state_prefix="state",
state_is_tuple=True):
"""Run the actual computation of one step LSTM.
@@ -665,8 +684,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
actual_input_size = freq_inputs[0].get_shape().as_list()[1]
concat_w_f = _get_concat_variable(
- "W_f_%d" % block, [actual_input_size + 2 * self._num_units,
- num_gates * self._num_units],
+ "W_f_%d" % block,
+ [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
dtype, self._num_unit_shards)
b_f = vs.get_variable(
"B_f_%d" % block,
@@ -674,10 +693,9 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
initializer=init_ops.zeros_initializer(),
dtype=dtype)
if not self._share_time_frequency_weights:
- concat_w_t = _get_concat_variable(
- "W_t_%d" % block, [actual_input_size + 2 * self._num_units,
- num_gates * self._num_units],
- dtype, self._num_unit_shards)
+ concat_w_t = _get_concat_variable("W_t_%d" % block, [
+ actual_input_size + 2 * self._num_units, num_gates * self._num_units
+ ], dtype, self._num_unit_shards)
b_t = vs.get_variable(
"B_t_%d" % block,
shape=[num_gates * self._num_units],
@@ -690,7 +708,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
w_f_diag_freqf = vs.get_variable(
"W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
w_f_diag_freqt = vs.get_variable(
- "W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype)
+ "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
w_i_diag_freqf = vs.get_variable(
"W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
w_i_diag_freqt = vs.get_variable(
@@ -724,8 +742,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
m_prev_time = getattr(state, name_prefix + "_m")
else:
c_prev_time = array_ops.slice(
- state, [0, 2 * freq_index * self._num_units],
- [-1, self._num_units])
+ state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
m_prev_time = array_ops.slice(
state, [0, (2 * freq_index + 1) * self._num_units],
[-1, self._num_units])
@@ -735,8 +752,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
[freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)
# F-LSTM
- lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs,
- concat_w_f), b_f)
+ lstm_matrix_freq = nn_ops.bias_add(
+ math_ops.matmul(cell_inputs, concat_w_f), b_f)
if self._couple_input_forget_gates:
i_freq, j_freq, o_freq = array_ops.split(
value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
@@ -751,8 +768,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
f_time = f_freq
o_time = o_freq
else:
- lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs,
- concat_w_t), b_t)
+ lstm_matrix_time = nn_ops.bias_add(
+ math_ops.matmul(cell_inputs, concat_w_t), b_t)
if self._couple_input_forget_gates:
i_time, j_time, o_time = array_ops.split(
value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
@@ -764,8 +781,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
# F-LSTM c_freq
# input gate activations
if self._use_peepholes:
- i_freq_g = sigmoid(i_freq +
- w_i_diag_freqf * c_prev_freq +
+ i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
w_i_diag_freqt * c_prev_time)
else:
i_freq_g = sigmoid(i_freq)
@@ -774,9 +790,8 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
f_freq_g = 1.0 - i_freq_g
else:
if self._use_peepholes:
- f_freq_g = sigmoid(f_freq + self._forget_bias +
- w_f_diag_freqf * c_prev_freq +
- w_f_diag_freqt * c_prev_time)
+ f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
+ c_prev_freq + w_f_diag_freqt * c_prev_time)
else:
f_freq_g = sigmoid(f_freq + self._forget_bias)
# cell state
@@ -791,12 +806,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
# input gate activations
if self._use_peepholes:
if self._share_time_frequency_weights:
- i_time_g = sigmoid(i_time +
- w_i_diag_freqf * c_prev_freq +
+ i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
w_i_diag_freqt * c_prev_time)
else:
- i_time_g = sigmoid(i_time +
- w_i_diag_timef * c_prev_freq +
+ i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
w_i_diag_timet * c_prev_time)
else:
i_time_g = sigmoid(i_time)
@@ -806,13 +819,11 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
else:
if self._use_peepholes:
if self._share_time_frequency_weights:
- f_time_g = sigmoid(f_time + self._forget_bias +
- w_f_diag_freqf * c_prev_freq +
- w_f_diag_freqt * c_prev_time)
+ f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
+ c_prev_freq + w_f_diag_freqt * c_prev_time)
else:
- f_time_g = sigmoid(f_time + self._forget_bias +
- w_f_diag_timef * c_prev_freq +
- w_f_diag_timet * c_prev_time)
+ f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
+ c_prev_freq + w_f_diag_timet * c_prev_time)
else:
f_time_g = sigmoid(f_time + self._forget_bias)
# cell state
@@ -825,8 +836,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
# F-LSTM m_freq
if self._use_peepholes:
- m_freq = sigmoid(o_freq +
- w_o_diag_freqf * c_freq +
+ m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
w_o_diag_freqt * c_time) * tanh(c_freq)
else:
m_freq = sigmoid(o_freq) * tanh(c_freq)
@@ -834,12 +844,10 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
# T-LSTM m_time
if self._use_peepholes:
if self._share_time_frequency_weights:
- m_time = sigmoid(o_time +
- w_o_diag_freqf * c_freq +
+ m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
w_o_diag_freqt * c_time) * tanh(c_time)
else:
- m_time = sigmoid(o_time +
- w_o_diag_timef * c_freq +
+ m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
w_o_diag_timet * c_time) * tanh(c_time)
else:
m_time = sigmoid(o_time) * tanh(c_time)
@@ -878,16 +886,18 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
raise ValueError("Cannot infer input_size from static shape inference.")
if slice_offset > 0:
# Padding to the end
- inputs = array_ops.pad(
- input_feat, array_ops.constant([0, 0, 0, slice_offset], shape=[2, 2],
- dtype=dtypes.int32),
- "CONSTANT")
+ inputs = array_ops.pad(input_feat,
+ array_ops.constant(
+ [0, 0, 0, slice_offset],
+ shape=[2, 2],
+ dtype=dtypes.int32), "CONSTANT")
elif slice_offset < 0:
# Padding to the front
- inputs = array_ops.pad(
- input_feat, array_ops.constant([0, 0, -slice_offset, 0], shape=[2, 2],
- dtype=dtypes.int32),
- "CONSTANT")
+ inputs = array_ops.pad(input_feat,
+ array_ops.constant(
+ [0, 0, -slice_offset, 0],
+ shape=[2, 2],
+ dtype=dtypes.int32), "CONSTANT")
slice_offset = 0
else:
inputs = input_feat
@@ -897,13 +907,13 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
raise ValueError("Length of num_frequency_blocks"
" is not 1, but instead is %d",
len(self._num_frequency_blocks))
- num_feats = int((input_size - self._feature_size) / (
- self._frequency_skip)) + 1
+ num_feats = int(
+ (input_size - self._feature_size) / (self._frequency_skip)) + 1
if num_feats != self._num_frequency_blocks[0]:
raise ValueError(
"Invalid num_frequency_blocks, requires %d but gets %d, please"
- " check the input size and filter config are correct." % (
- self._num_frequency_blocks[0], num_feats))
+ " check the input size and filter config are correct." %
+ (self._num_frequency_blocks[0], num_feats))
block_inputs = []
for f in range(num_feats):
cur_input = array_ops.slice(
@@ -926,18 +936,18 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
start_index = self._start_freqindex_list[b]
end_index = self._end_freqindex_list[b]
cur_size = end_index - start_index
- block_feats = int((cur_size - self._feature_size) / (
- self._frequency_skip)) + 1
+ block_feats = int(
+ (cur_size - self._feature_size) / (self._frequency_skip)) + 1
if block_feats != self._num_frequency_blocks[b]:
raise ValueError(
"Invalid num_frequency_blocks, requires %d but gets %d, please"
- " check the input size and filter config are correct." % (
- self._num_frequency_blocks[b], block_feats))
+ " check the input size and filter config are correct." %
+ (self._num_frequency_blocks[b], block_feats))
block_inputs = []
for f in range(block_feats):
cur_input = array_ops.slice(
- inputs, [0, start_index + slice_offset + f *
- self._frequency_skip],
+ inputs,
+ [0, start_index + slice_offset + f * self._frequency_skip],
[-1, self._feature_size])
block_inputs.append(cur_input)
freq_inputs.append(block_inputs)
@@ -953,11 +963,16 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
The current implementation uses different weights for the two directions.
"""
- def __init__(self, num_units, use_peepholes=False,
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
share_time_frequency_weights=False,
- cell_clip=None, initializer=None,
- num_unit_shards=1, forget_bias=1.0,
- feature_size=None, frequency_skip=None,
+ cell_clip=None,
+ initializer=None,
+ num_unit_shards=1,
+ forget_bias=1.0,
+ feature_size=None,
+ frequency_skip=None,
num_frequency_blocks=None,
start_freqindex_list=None,
end_freqindex_list=None,
@@ -1016,8 +1031,8 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
self._state_tuple_type = collections.namedtuple(
"BidirectionalGridLSTMStateTuple", state_names.strip(","))
- self._state_size = self._state_tuple_type(
- *([num_units, num_units] * self._total_blocks * 2))
+ self._state_size = self._state_tuple_type(*(
+ [num_units, num_units] * self._total_blocks * 2))
self._output_size = 2 * num_units * self._total_blocks * 2
def call(self, inputs, state):
@@ -1051,8 +1066,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
fwd_state_out_lst = []
for block in range(len(fwd_inputs)):
fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
- fwd_inputs[block], block, state, batch_size,
- state_prefix="fwd_state", state_is_tuple=True)
+ fwd_inputs[block],
+ block,
+ state,
+ batch_size,
+ state_prefix="fwd_state",
+ state_is_tuple=True)
fwd_m_out_lst.extend(fwd_m_out_lst_current)
fwd_state_out_lst.extend(fwd_state_out_lst_current)
# Backward processing
@@ -1063,8 +1082,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
# Reverse the blocks
bwd_inputs_reverse = bwd_inputs[block][::-1]
bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
- bwd_inputs_reverse, block, state, batch_size,
- state_prefix="bwd_state", state_is_tuple=True)
+ bwd_inputs_reverse,
+ block,
+ state,
+ batch_size,
+ state_prefix="bwd_state",
+ state_is_tuple=True)
bwd_m_out_lst.extend(bwd_m_out_lst_current)
bwd_state_out_lst.extend(bwd_state_out_lst_current)
state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
@@ -1075,6 +1098,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
# pylint: disable=protected-access
_Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
+
# pylint: enable=protected-access
@@ -1084,8 +1108,14 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
Implementation based on https://arxiv.org/abs/1409.0473.
"""
- def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
- input_size=None, state_is_tuple=True, reuse=None):
+ def __init__(self,
+ cell,
+ attn_length,
+ attn_size=None,
+ attn_vec_size=None,
+ input_size=None,
+ state_is_tuple=True,
+ reuse=None):
"""Create a cell with attention.
Args:
@@ -1115,16 +1145,15 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError("The parameter cell is not RNNCell.")
if nest.is_sequence(cell.state_size) and not state_is_tuple:
- raise ValueError("Cell returns tuple of states, but the flag "
- "state_is_tuple is not set. State size is: %s"
- % str(cell.state_size))
+ raise ValueError(
+ "Cell returns tuple of states, but the flag "
+ "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
if attn_length <= 0:
- raise ValueError("attn_length should be greater than zero, got %s"
- % str(attn_length))
+ raise ValueError(
+ "attn_length should be greater than zero, got %s" % str(attn_length))
if not state_is_tuple:
- logging.warn(
- "%s: Using a concatenated state is slower and will soon be "
- "deprecated. Use state_is_tuple=True.", self)
+ logging.warn("%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
if attn_size is None:
attn_size = cell.output_size
if attn_vec_size is None:
@@ -1160,8 +1189,8 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
else:
states = state
state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
- attns = array_ops.slice(
- states, [0, self._cell.state_size], [-1, self._attn_size])
+ attns = array_ops.slice(states, [0, self._cell.state_size],
+ [-1, self._attn_size])
attn_states = array_ops.slice(
states, [0, self._cell.state_size + self._attn_size],
[-1, self._attn_size * self._attn_length])
@@ -1199,8 +1228,8 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell):
tanh = math_ops.tanh
with vs.variable_scope("attention"):
- k = vs.get_variable(
- "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
+ k = vs.get_variable("attn_w",
+ [1, 1, self._attn_size, self._attn_vec_size])
v = vs.get_variable("attn_v", [self._attn_vec_size])
hidden = array_ops.reshape(attn_states,
[-1, self._attn_length, 1, self._attn_size])
@@ -1227,7 +1256,8 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
https://arxiv.org/abs/1505.00387
"""
- def __init__(self, cell,
+ def __init__(self,
+ cell,
couple_carry_transform_gates=True,
carry_bias_init=1.0):
"""Constructs a `HighwayWrapper` for `cell`.
@@ -1259,8 +1289,7 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
carry_weight = vs.get_variable("carry_w", [input_size, input_size])
carry_bias = vs.get_variable(
"carry_b", [input_size],
- initializer=init_ops.constant_initializer(
- self._carry_bias_init))
+ initializer=init_ops.constant_initializer(self._carry_bias_init))
carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
if self._couple_carry_transform_gates:
transform = 1 - carry
@@ -1269,11 +1298,9 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
[input_size, input_size])
transform_bias = vs.get_variable(
"transform_b", [input_size],
- initializer=init_ops.constant_initializer(
- -self._carry_bias_init))
- transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp,
- transform_weight,
- transform_bias))
+ initializer=init_ops.constant_initializer(-self._carry_bias_init))
+ transform = math_ops.sigmoid(
+ nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
return inp * carry + out * transform
def __call__(self, inputs, state, scope=None):
@@ -1293,9 +1320,11 @@ class HighwayWrapper(rnn_cell_impl.RNNCell):
"""
outputs, new_state = self._cell(inputs, state, scope=scope)
nest.assert_same_structure(inputs, outputs)
+
# Ensure shapes match
def assert_shape_match(inp, out):
inp.get_shape().assert_is_compatible_with(out.get_shape())
+
nest.map_structure(assert_shape_match, inputs, outputs)
res_outputs = nest.map_structure(self._highway, inputs, outputs)
return (res_outputs, new_state)
@@ -1321,10 +1350,16 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
"""
- def __init__(self, num_units, forget_bias=1.0,
- input_size=None, activation=math_ops.tanh,
- layer_norm=True, norm_gain=1.0, norm_shift=0.0,
- dropout_keep_prob=1.0, dropout_prob_seed=None,
+ def __init__(self,
+ num_units,
+ forget_bias=1.0,
+ input_size=None,
+ activation=math_ops.tanh,
+ layer_norm=True,
+ norm_gain=1.0,
+ norm_shift=0.0,
+ dropout_keep_prob=1.0,
+ dropout_prob_seed=None,
reuse=None):
"""Initializes the basic LSTM cell.
@@ -1409,8 +1444,8 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
- new_c = (c * math_ops.sigmoid(f + self._forget_bias)
- + math_ops.sigmoid(i) * g)
+ new_c = (
+ c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
if self._layer_norm:
new_c = self._norm(new_c, "state", dtype=dtype)
new_h = self._activation(new_c) * math_ops.sigmoid(o)
@@ -1432,8 +1467,7 @@ class NASCell(rnn_cell_impl.RNNCell):
The class uses an optional projection layer.
"""
- def __init__(self, num_units, num_proj=None,
- use_biases=False, reuse=None):
+ def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None):
"""Initialize the parameters for a NAS cell.
Args:
@@ -1503,12 +1537,10 @@ class NASCell(rnn_cell_impl.RNNCell):
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
# Variables for the NAS cell. W_m is all matrices multiplying the
# hiddenstate and W_inputs is all matrices multiplying the inputs.
- concat_w_m = vs.get_variable(
- "recurrent_kernel", [num_proj, 8 * self._num_units],
- dtype)
+ concat_w_m = vs.get_variable("recurrent_kernel",
+ [num_proj, 8 * self._num_units], dtype)
concat_w_inputs = vs.get_variable(
- "kernel", [input_size.value, 8 * self._num_units],
- dtype)
+ "kernel", [input_size.value, 8 * self._num_units], dtype)
m_matrix = math_ops.matmul(m_prev, concat_w_m)
inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
@@ -1523,10 +1555,10 @@ class NASCell(rnn_cell_impl.RNNCell):
# The NAS cell branches into 8 different splits for both the hiddenstate
# and the input
- m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
- value=m_matrix)
- inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
- value=inputs_matrix)
+ m_matrix_splits = array_ops.split(
+ axis=1, num_or_size_splits=8, value=m_matrix)
+ inputs_matrix_splits = array_ops.split(
+ axis=1, num_or_size_splits=8, value=inputs_matrix)
# First layer
layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
@@ -1558,9 +1590,8 @@ class NASCell(rnn_cell_impl.RNNCell):
# Projection layer if specified
if self._num_proj is not None:
- concat_w_proj = vs.get_variable(
- "projection_weights", [self._num_units, self._num_proj],
- dtype)
+ concat_w_proj = vs.get_variable("projection_weights",
+ [self._num_units, self._num_proj], dtype)
new_m = math_ops.matmul(new_m, concat_w_proj)
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
@@ -1583,8 +1614,12 @@ class UGRNNCell(rnn_cell_impl.RNNCell):
"Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
"""
- def __init__(self, num_units, initializer=None, forget_bias=1.0,
- activation=math_ops.tanh, reuse=None):
+ def __init__(self,
+ num_units,
+ initializer=None,
+ forget_bias=1.0,
+ activation=math_ops.tanh,
+ reuse=None):
"""Initialize the parameters for an UGRNN cell.
Args:
@@ -1639,8 +1674,8 @@ class UGRNNCell(rnn_cell_impl.RNNCell):
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
- with vs.variable_scope(vs.get_variable_scope(),
- initializer=self._initializer):
+ with vs.variable_scope(
+ vs.get_variable_scope(), initializer=self._initializer):
cell_inputs = array_ops.concat([inputs, state], 1)
if self._linear is None:
self._linear = _Linear(cell_inputs, 2 * self._num_units, True)
@@ -1680,9 +1715,13 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
RNNs so it may not achieve best performance with depth 1.
"""
- def __init__(self, num_units, num_in_proj=None,
- initializer=None, forget_bias=1.0,
- y_activation=nn_ops.relu, reuse=None):
+ def __init__(self,
+ num_units,
+ num_in_proj=None,
+ initializer=None,
+ forget_bias=1.0,
+ y_activation=nn_ops.relu,
+ reuse=None):
"""Initialize the parameters for an +RNN cell.
Args:
@@ -1746,8 +1785,8 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
- with vs.variable_scope(vs.get_variable_scope(),
- initializer=self._initializer):
+ with vs.variable_scope(
+ vs.get_variable_scope(), initializer=self._initializer):
# read-in projections (should be used for first layer in deep +RNN
# to transform size of inputs from I --> N)
if input_size.value != self._num_units:
@@ -1764,13 +1803,13 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell):
n_dim = i_dim = self._num_units
cell_inputs = array_ops.concat([inputs, state], 1)
if self._linear2 is None:
- self._linear2 = _Linear(cell_inputs, 2*n_dim + 2*i_dim, True)
+ self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True)
rnn_matrix = self._linear2(cell_inputs)
- gh_act = rnn_matrix[:, :n_dim] # b x n
- h_act = rnn_matrix[:, n_dim:2*n_dim] # b x n
- gy_act = rnn_matrix[:, 2*n_dim:2*n_dim+i_dim] # b x i
- y_act = rnn_matrix[:, 2*n_dim+i_dim:2*n_dim+2*i_dim] # b x i
+ gh_act = rnn_matrix[:, :n_dim] # b x n
+ h_act = rnn_matrix[:, n_dim:2 * n_dim] # b x n
+ gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim] # b x i
+ y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim] # b x i
h = tanh(h_act)
y = self._y_activation(y_act)
@@ -1816,6 +1855,7 @@ class CompiledWrapper(rnn_cell_impl.RNNCell):
if self._compile_stateful:
compile_ops = True
else:
+
def compile_ops(node_def):
global _REGISTERED_OPS
if _REGISTERED_OPS is None:
@@ -1826,10 +1866,7 @@ class CompiledWrapper(rnn_cell_impl.RNNCell):
return self._cell(inputs, state, scope=scope)
-def _random_exp_initializer(minval,
- maxval,
- seed=None,
- dtype=dtypes.float32):
+def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
"""Returns an exponential distribution initializer.
Args:
@@ -1848,10 +1885,7 @@ def _random_exp_initializer(minval,
del partition_info # Unused.
return math_ops.exp(
random_ops.random_uniform(
- shape,
- math_ops.log(minval),
- math_ops.log(maxval),
- dtype,
+ shape, math_ops.log(minval), math_ops.log(maxval), dtype,
seed=seed))
return _initializer
@@ -1955,8 +1989,7 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
if self._linear1 is None:
self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True)
- mask_gates = math_ops.sigmoid(
- self._linear1(in_mask_gates))
+ mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates))
[input_gate, forget_gate] = array_ops.split(
axis=1, num_or_size_splits=2, value=mask_gates)
@@ -1980,12 +2013,12 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
period = vs.get_variable(
"period", [self._num_units],
- initializer=_random_exp_initializer(
- self._period_init_min, self._period_init_max))
+ initializer=_random_exp_initializer(self._period_init_min,
+ self._period_init_max))
phase = vs.get_variable(
"phase", [self._num_units],
- initializer=init_ops.random_uniform_initializer(
- 0., period.initial_value))
+ initializer=init_ops.random_uniform_initializer(0.,
+ period.initial_value))
ratio_on = vs.get_variable(
"ratio_on", [self._num_units],
initializer=init_ops.constant_initializer(self._ratio_on),
@@ -2007,6 +2040,7 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell):
return new_h, new_state
+
class ConvLSTMCell(rnn_cell_impl.RNNCell):
"""Convolutional LSTM recurrent network cell.
@@ -2040,7 +2074,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
"""
super(ConvLSTMCell, self).__init__(name=name)
- if conv_ndims != len(input_shape)-1:
+ if conv_ndims != len(input_shape) - 1:
raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
input_shape, conv_ndims))
@@ -2059,8 +2093,8 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
state_size = tensor_shape.TensorShape(
self._input_shape[:-1] + [self._output_channels])
self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
- self._output_size = tensor_shape.TensorShape(self._input_shape[:-1]
- + [self._total_output_channels])
+ self._output_size = tensor_shape.TensorShape(
+ self._input_shape[:-1] + [self._total_output_channels])
@property
def output_size(self):
@@ -2072,13 +2106,10 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
def call(self, inputs, state, scope=None):
cell, hidden = state
- new_hidden = _conv([inputs, hidden],
- self._kernel_shape,
- 4*self._output_channels,
- self._use_bias)
- gates = array_ops.split(value=new_hidden,
- num_or_size_splits=4,
- axis=self._conv_ndims+1)
+ new_hidden = _conv([inputs, hidden], self._kernel_shape,
+ 4 * self._output_channels, self._use_bias)
+ gates = array_ops.split(
+ value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)
input_gate, new_input, forget_gate, output_gate = gates
new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
@@ -2090,29 +2121,35 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell):
new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
return output, new_state
+
class Conv1DLSTMCell(ConvLSTMCell):
"""1D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
+
def __init__(self, name="conv_1d_lstm_cell", **kwargs):
"""Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs)
+
class Conv2DLSTMCell(ConvLSTMCell):
"""2D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
+
def __init__(self, name="conv_2d_lstm_cell", **kwargs):
"""Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs)
+
class Conv3DLSTMCell(ConvLSTMCell):
"""3D Convolutional LSTM recurrent network cell.
https://arxiv.org/pdf/1506.04214v1.pdf
"""
+
def __init__(self, name="conv_3d_lstm_cell", **kwargs):
"""Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
@@ -2137,7 +2174,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
shapes = [a.get_shape().as_list() for a in args]
shape_length = len(shapes[0])
for shape in shapes:
- if len(shape) not in [3,4,5]:
+ if len(shape) not in [3, 4, 5]:
raise ValueError("Conv Linear expects 3D, 4D "
"or 5D arguments: %s" % str(shapes))
if len(shape) != len(shapes[0]):
@@ -2148,40 +2185,36 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0):
dtype = [a.dtype for a in args][0]
# determine correct conv operation
- if shape_length == 3:
+ if shape_length == 3:
conv_op = nn_ops.conv1d
strides = 1
elif shape_length == 4:
conv_op = nn_ops.conv2d
- strides = shape_length*[1]
+ strides = shape_length * [1]
elif shape_length == 5:
conv_op = nn_ops.conv3d
- strides = shape_length*[1]
+ strides = shape_length * [1]
# Now the computation.
kernel = vs.get_variable(
- "kernel",
- filter_size + [total_arg_size_depth, num_features],
- dtype=dtype)
+ "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype)
if len(args) == 1:
- res = conv_op(args[0],
- kernel,
- strides,
- padding='SAME')
+ res = conv_op(args[0], kernel, strides, padding="SAME")
else:
- res = conv_op(array_ops.concat(axis=shape_length-1, values=args),
- kernel,
- strides,
- padding='SAME')
+ res = conv_op(
+ array_ops.concat(axis=shape_length - 1, values=args),
+ kernel,
+ strides,
+ padding="SAME")
if not bias:
return res
bias_term = vs.get_variable(
"biases", [num_features],
dtype=dtype,
- initializer=init_ops.constant_initializer(
- bias_start, dtype=dtype))
+ initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
return res + bias_term
+
class GLSTMCell(rnn_cell_impl.RNNCell):
"""Group LSTM cell (G-LSTM).
@@ -2193,8 +2226,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
"Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
"""
- def __init__(self, num_units, initializer=None, num_proj=None,
- number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh,
+ def __init__(self,
+ num_units,
+ initializer=None,
+ num_proj=None,
+ number_of_groups=1,
+ forget_bias=1.0,
+ activation=math_ops.tanh,
reuse=None):
"""Initialize the parameters of G-LSTM cell.
@@ -2231,11 +2269,15 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
if self._num_proj:
if self._num_proj % self._number_of_groups != 0:
raise ValueError("num_proj must be divisible by number_of_groups")
- self._group_shape = [int(self._num_proj / self._number_of_groups),
- int(self._num_units / self._number_of_groups)]
+ self._group_shape = [
+ int(self._num_proj / self._number_of_groups),
+ int(self._num_units / self._number_of_groups)
+ ]
else:
- self._group_shape = [int(self._num_units / self._number_of_groups),
- int(self._num_units / self._number_of_groups)]
+ self._group_shape = [
+ int(self._num_units / self._number_of_groups),
+ int(self._num_units / self._number_of_groups)
+ ]
if num_proj:
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
@@ -2267,10 +2309,11 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
subset of inputs corresponding to group "group_id",
a Tensor, 2D, [batch x num_units/number_of_groups]
"""
- return array_ops.slice(input_=inputs,
- begin=[0, group_id * group_size],
- size=[self._batch_size, group_size],
- name=("GLSTM_group%d_input_generation" % group_id))
+ return array_ops.slice(
+ input_=inputs,
+ begin=[0, group_id * group_size],
+ size=[self._batch_size, group_size],
+ name=("GLSTM_group%d_input_generation" % group_id))
def call(self, inputs, state):
"""Run one step of G-LSTM.
@@ -2309,10 +2352,13 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
for group_id in range(self._number_of_groups):
with vs.variable_scope("group%d" % group_id):
x_g_id = array_ops.concat(
- [self._get_input_for_group(inputs, group_id,
- self._group_shape[0]),
- self._get_input_for_group(m_prev, group_id,
- self._group_shape[0])], axis=1)
+ [
+ self._get_input_for_group(inputs, group_id,
+ self._group_shape[0]),
+ self._get_input_for_group(m_prev, group_id,
+ self._group_shape[0])
+ ],
+ axis=1)
if self._linear1 is None:
self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False)
R_k = self._linear1(x_g_id) # pylint: disable=invalid-name
@@ -2323,34 +2369,35 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
f_parts.append(f_k)
o_parts.append(o_k)
- bi = vs.get_variable(name="bias_i",
- shape=[self._num_units],
- dtype=dtype,
- initializer=
- init_ops.constant_initializer(0.0, dtype=dtype))
- bj = vs.get_variable(name="bias_j",
- shape=[self._num_units],
- dtype=dtype,
- initializer=
- init_ops.constant_initializer(0.0, dtype=dtype))
- bf = vs.get_variable(name="bias_f",
- shape=[self._num_units],
- dtype=dtype,
- initializer=
- init_ops.constant_initializer(0.0, dtype=dtype))
- bo = vs.get_variable(name="bias_o",
- shape=[self._num_units],
- dtype=dtype,
- initializer=
- init_ops.constant_initializer(0.0, dtype=dtype))
+ bi = vs.get_variable(
+ name="bias_i",
+ shape=[self._num_units],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+ bj = vs.get_variable(
+ name="bias_j",
+ shape=[self._num_units],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+ bf = vs.get_variable(
+ name="bias_f",
+ shape=[self._num_units],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(0.0, dtype=dtype))
+ bo = vs.get_variable(
+ name="bias_o",
+ shape=[self._num_units],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(0.0, dtype=dtype))
i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)
- c = (math_ops.sigmoid(f + self._forget_bias) * c_prev +
- math_ops.sigmoid(i) * math_ops.tanh(j))
+ c = (
+ math_ops.sigmoid(f + self._forget_bias) * c_prev +
+ math_ops.sigmoid(i) * math_ops.tanh(j))
m = math_ops.sigmoid(o) * self._activation(c)
if self._num_proj is not None:
@@ -2635,10 +2682,12 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
class SRUCell(rnn_cell_impl._LayerRNNCell):
"""SRU, Simple Recurrent Unit
+
Implementation based on
Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).
- This variation of RNN cell is characterized by the simplified data dependence
+ This variation of RNN cell is characterized by the simplified data
+ dependence
between hidden states of two consecutive time steps. Traditionally, hidden
states from a cell at time step t-1 needs to be multiplied with a matrix
W_hh before being fed into the ensuing cell at time step t.
@@ -2656,8 +2705,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
will share weights, but to avoid mistakes we require reuse=True in such
cases.
"""
- def __init__(self, num_units,
- activation=None, reuse=None, name=None):
+
+ def __init__(self, num_units, activation=None, reuse=None, name=None):
super(SRUCell, self).__init__(_reuse=reuse, name=name)
self._num_units = num_units
self._activation = activation or math_ops.tanh
@@ -2675,8 +2724,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[1].value is None:
- raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
input_depth = inputs_shape[1].value
@@ -2711,15 +2760,276 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
"""Simple recurrent unit (SRU) with num_units cells."""
U = math_ops.matmul(inputs, self._kernel)
- x_bar, f_intermediate, r_intermediate = array_ops.split(value=U,
- num_or_size_splits=3,
- axis=1)
+ x_bar, f_intermediate, r_intermediate = array_ops.split(
+ value=U, num_or_size_splits=3, axis=1)
- f_r = math_ops.sigmoid(nn_ops.bias_add(array_ops.concat(
- [f_intermediate, r_intermediate], 1), self._bias))
+ f_r = math_ops.sigmoid(
+ nn_ops.bias_add(
+ array_ops.concat([f_intermediate, r_intermediate], 1), self._bias))
f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
c = f * state + (1.0 - f) * x_bar
h = r * self._activation(c) + (1.0 - r) * inputs
return h, c
+
+
+class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
+ """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.
+
+ The weight-norm implementation is based on:
+ https://arxiv.org/abs/1602.07868
+ Tim Salimans, Diederik P. Kingma.
+ Weight Normalization: A Simple Reparameterization to Accelerate
+ Training of Deep Neural Networks
+
+ The default LSTM implementation based on:
+ http://www.bioinf.jku.at/publications/older/2604.pdf
+ S. Hochreiter and J. Schmidhuber.
+ "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+ The class uses optional peephole connections, optional cell clipping
+ and an optional projection layer.
+
+ The optional peephole implementation is based on:
+ https://research.google.com/pubs/archive/43905.pdf
+ Hasim Sak, Andrew Senior, and Francoise Beaufays.
+ "Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+ """
+
+ def __init__(self,
+ num_units,
+ norm=True,
+ use_peepholes=False,
+ cell_clip=None,
+ initializer=None,
+ num_proj=None,
+ proj_clip=None,
+ forget_bias=1,
+ activation=None,
+ reuse=None):
+ """Initialize the parameters of a weight-normalized LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell
+ norm: If `True`, apply normalization to the weight matrices. If False,
+ the result is identical to that obtained from `rnn_cell_impl.LSTMCell`
+ use_peepholes: bool, set `True` to enable diagonal/peephole connections.
+ cell_clip: (optional) A float value, if provided the cell state is clipped
+ by this value prior to the cell output activation.
+ initializer: (optional) The initializer to use for the weight matrices.
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
+ provided, then the projected values are clipped elementwise to within
+ `[-proj_clip, proj_clip]`.
+ forget_bias: Biases of the forget gate are initialized by default to 1
+ in order to reduce the scale of forgetting at the beginning of
+ the training.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ """
+ super(WeightNormLSTMCell, self).__init__(_reuse=reuse)
+
+ self._scope = "wn_lstm_cell"
+ self._num_units = num_units
+ self._norm = norm
+ self._initializer = initializer
+ self._use_peepholes = use_peepholes
+ self._cell_clip = cell_clip
+ self._num_proj = num_proj
+ self._proj_clip = proj_clip
+ self._activation = activation or math_ops.tanh
+ self._forget_bias = forget_bias
+
+ self._weights_variable_name = "kernel"
+ self._bias_variable_name = "bias"
+
+ if num_proj:
+ self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
+ self._output_size = num_proj
+ else:
+ self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
+ self._output_size = num_units
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ def _normalize(self, weight, name):
+ """Apply weight normalization.
+
+ Args:
+ weight: a 2D tensor with known number of columns.
+ name: string, variable name for the normalizer.
+ Returns:
+ A tensor with the same shape as `weight`.
+ """
+
+ output_size = weight.get_shape().as_list()[1]
+ g = vs.get_variable(name, [output_size], dtype=weight.dtype)
+ return nn_impl.l2_normalize(weight, dim=0) * g
+
+ def _linear(self,
+ args,
+ output_size,
+ norm,
+ bias,
+ bias_initializer=None,
+ kernel_initializer=None):
+ """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
+
+ Args:
+ args: a 2D Tensor or a list of 2D, batch x n, Tensors.
+ output_size: int, second dimension of W[i].
+ bias: boolean, whether to add a bias term or not.
+ bias_initializer: starting value to initialize the bias
+ (default is all zeros).
+ kernel_initializer: starting value to initialize the weight.
+
+ Returns:
+ A 2D Tensor with shape [batch x output_size] equal to
+ sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
+
+ Raises:
+ ValueError: if some of the arguments has unspecified or wrong shape.
+ """
+ if args is None or (nest.is_sequence(args) and not args):
+ raise ValueError("`args` must be specified")
+ if not nest.is_sequence(args):
+ args = [args]
+
+ # Calculate the total size of arguments on dimension 1.
+ total_arg_size = 0
+ shapes = [a.get_shape() for a in args]
+ for shape in shapes:
+ if shape.ndims != 2:
+ raise ValueError("linear is expecting 2D arguments: %s" % shapes)
+ if shape[1].value is None:
+ raise ValueError("linear expects shape[1] to be provided for shape %s, "
+ "but saw %s" % (shape, shape[1]))
+ else:
+ total_arg_size += shape[1].value
+
+ dtype = [a.dtype for a in args][0]
+
+ # Now the computation.
+ scope = vs.get_variable_scope()
+ with vs.variable_scope(scope) as outer_scope:
+ weights = vs.get_variable(
+ self._weights_variable_name, [total_arg_size, output_size],
+ dtype=dtype,
+ initializer=kernel_initializer)
+ if norm:
+ wn = []
+ st = 0
+ with ops.control_dependencies(None):
+ for i in range(len(args)):
+ en = st + shapes[i][1].value
+ wn.append(
+ self._normalize(weights[st:en, :], name="norm_{}".format(i)))
+ st = en
+
+ weights = array_ops.concat(wn, axis=0)
+
+ if len(args) == 1:
+ res = math_ops.matmul(args[0], weights)
+ else:
+ res = math_ops.matmul(array_ops.concat(args, 1), weights)
+ if not bias:
+ return res
+
+ with vs.variable_scope(outer_scope) as inner_scope:
+ inner_scope.set_partitioner(None)
+ if bias_initializer is None:
+ bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
+
+ biases = vs.get_variable(
+ self._bias_variable_name, [output_size],
+ dtype=dtype,
+ initializer=bias_initializer)
+
+ return nn_ops.bias_add(res, biases)
+
+ def call(self, inputs, state):
+ """Run one step of LSTM.
+
+ Args:
+ inputs: input Tensor, 2D, batch x num_units.
+ state: A tuple of state Tensors, both `2-D`, with column sizes
+ `c_state` and `m_state`.
+
+ Returns:
+ A tuple containing:
+
+ - A `2-D, [batch x output_dim]`, Tensor representing the output of the
+ LSTM after reading `inputs` when previous state was `state`.
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - Tensor(s) representing the new state of LSTM after reading `inputs` when
+ the previous state was `state`. Same type and shape(s) as `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ dtype = inputs.dtype
+ num_units = self._num_units
+ sigmoid = math_ops.sigmoid
+ c, h = state
+
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ with vs.variable_scope(self._scope, initializer=self._initializer):
+
+ concat = self._linear(
+ [inputs, h], 4 * num_units, norm=self._norm, bias=True)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
+
+ if self._use_peepholes:
+ w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
+ w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
+ w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
+
+ new_c = (
+ c * sigmoid(f + self._forget_bias + w_f_diag * c) +
+ sigmoid(i + w_i_diag * c) * self._activation(j))
+ else:
+ new_c = (
+ c * sigmoid(f + self._forget_bias) +
+ sigmoid(i) * self._activation(j))
+
+ if self._cell_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
+ # pylint: enable=invalid-unary-operand-type
+ if self._use_peepholes:
+ new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
+ else:
+ new_h = sigmoid(o) * self._activation(new_c)
+
+ if self._num_proj is not None:
+ with vs.variable_scope("projection"):
+ new_h = self._linear(
+ new_h, self._num_proj, norm=self._norm, bias=False)
+
+ if self._proj_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
+ self._proj_clip)
+ # pylint: enable=invalid-unary-operand-type
+
+ new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
+ return new_h, new_state
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
index c0df224bc8..b732cdd41e 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
@@ -15,8 +15,8 @@ limitations under the License.
// Helpers for working with the SignatureDefs of TensorFlow SavedModels.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
+#define TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
#include <string>
#include <utility>
@@ -66,4 +66,4 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
index 693b02dc43..34da8c82cd 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
+#ifndef TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
+#define TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -38,4 +38,4 @@ struct GatherTree {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
+#endif // TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index d2beac5f31..9265540317 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -46,20 +46,18 @@ class TestGatherTree(test.TestCase):
# create (batch_size, max_time, beam_width) matrix and transpose it
predicted_ids = np.array(
- [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
- [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
+ [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
dtype=np.int32).transpose([1, 0, 2])
parent_ids = np.array(
- [[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
- [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
+ [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
dtype=np.int32).transpose([1, 0, 2])
# sequence_lengths is shaped (batch_size = 3)
max_sequence_lengths = [3, 3]
- expected_result = np.array(
- [[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
- [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
+ expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
+ [[2, 4, 4], [7, 6, 6],
+ [8, 9, 10]]]).transpose([1, 0, 2])
res = beam_search_ops.gather_tree(
predicted_ids,
@@ -157,8 +155,8 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
- self.assertAllEqual(next_state_.finished, [[False, False, False],
- [False, False, False]])
+ self.assertAllEqual(next_state_.finished,
+ [[False, False, False], [False, False, False]])
expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -212,8 +210,8 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
- self.assertAllEqual(next_state_.finished, [[True, False, False],
- [False, True, False]])
+ self.assertAllEqual(next_state_.finished,
+ [[True, False, False], [False, True, False]])
expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -225,6 +223,100 @@ class TestBeamStep(test.TestCase):
self.assertAllEqual(next_state_.log_probs, expected_log_probs)
+class TestLargeBeamStep(test.TestCase):
+ """Tests large beam step.
+
+ Tests a single step of beam search in such case that beam size is larger than
+ vocabulary size.
+ """
+
+ def setUp(self):
+ super(TestLargeBeamStep, self).setUp()
+ self.batch_size = 2
+ self.beam_width = 8
+ self.vocab_size = 5
+ self.end_token = 0
+ self.length_penalty_weight = 0.6
+
+ def test_step(self):
+
+ def get_probs():
+ """this simulates the initialize method in BeamSearchDecoder."""
+ log_prob_mask = array_ops.one_hot(
+ array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+ depth=self.beam_width,
+ on_value=True,
+ off_value=False,
+ dtype=dtypes.bool)
+
+ log_prob_zeros = array_ops.zeros(
+ [self.batch_size, self.beam_width], dtype=dtypes.float32)
+ log_prob_neg_inf = array_ops.ones(
+ [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf
+
+ log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
+ log_prob_neg_inf)
+ return log_probs
+
+ log_probs = get_probs()
+ dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
+
+ # pylint: disable=invalid-name
+ _finished = array_ops.one_hot(
+ array_ops.zeros([self.batch_size], dtype=dtypes.int32),
+ depth=self.beam_width,
+ on_value=False,
+ off_value=True,
+ dtype=dtypes.bool)
+ _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
+ _lengths[:, 0] = 2
+ _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)
+
+ beam_state = beam_search_decoder.BeamSearchDecoderState(
+ cell_state=dummy_cell_state,
+ log_probs=log_probs,
+ lengths=_lengths,
+ finished=_finished)
+
+ logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
+ 0.0001)
+ logits_[0, 0, 2] = 1.9
+ logits_[0, 0, 3] = 2.1
+ logits_[0, 1, 3] = 3.1
+ logits_[0, 1, 4] = 0.9
+ logits_[1, 0, 1] = 0.5
+ logits_[1, 1, 2] = 2.7
+ logits_[1, 2, 2] = 10.0
+ logits_[1, 2, 3] = 0.2
+ logits = constant_op.constant(logits_, dtype=dtypes.float32)
+ log_probs = nn_ops.log_softmax(logits)
+
+ outputs, next_beam_state = beam_search_decoder._beam_search_step(
+ time=2,
+ logits=logits,
+ next_cell_state=dummy_cell_state,
+ beam_state=beam_state,
+ batch_size=ops.convert_to_tensor(self.batch_size),
+ beam_width=self.beam_width,
+ end_token=self.end_token,
+ length_penalty_weight=self.length_penalty_weight)
+
+ with self.test_session() as sess:
+ outputs_, next_state_, _, _ = sess.run(
+ [outputs, next_beam_state, beam_state, log_probs])
+
+ self.assertEqual(outputs_.predicted_ids[0, 0], 3)
+ self.assertEqual(outputs_.predicted_ids[0, 1], 2)
+ self.assertEqual(outputs_.predicted_ids[1, 0], 1)
+ neg_inf = -np.Inf
+ self.assertAllEqual(
+ next_state_.log_probs[:, -3:],
+ [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
+ self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True)
+ self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True)
+ self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])
+
+
class BeamSearchDecoderTest(test.TestCase):
def _testDynamicDecodeRNN(self, time_major, has_attention):
@@ -250,8 +342,8 @@ class BeamSearchDecoderTest(test.TestCase):
initial_state = cell.zero_state(batch_size, dtypes.float32)
if has_attention:
inputs = array_ops.placeholder_with_default(
- np.random.randn(batch_size, decoder_max_time,
- input_depth).astype(np.float32),
+ np.random.randn(batch_size, decoder_max_time, input_depth).astype(
+ np.float32),
shape=(None, None, input_depth))
tiled_inputs = beam_search_decoder.tile_batch(
inputs, multiplier=beam_width)
@@ -271,8 +363,7 @@ class BeamSearchDecoderTest(test.TestCase):
cell_state = cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
if has_attention:
- cell_state = cell_state.clone(
- cell_state=initial_state)
+ cell_state = cell_state.clone(cell_state=initial_state)
bsd = beam_search_decoder.BeamSearchDecoder(
cell=cell,
embedding=embedding,
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index ebe25ce077..d6184d6109 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import collections
-
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
@@ -38,7 +37,6 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
-
__all__ = [
"BeamSearchDecoderOutput",
"BeamSearchDecoderState",
@@ -49,8 +47,8 @@ __all__ = [
class BeamSearchDecoderState(
- collections.namedtuple("BeamSearchDecoderState", ("cell_state", "log_probs",
- "finished", "lengths"))):
+ collections.namedtuple("BeamSearchDecoderState",
+ ("cell_state", "log_probs", "finished", "lengths"))):
pass
@@ -86,11 +84,12 @@ def _tile_batch(t, multiplier):
tiled_static_batch_size = (
t.shape[0].value * multiplier if t.shape[0].value is not None else None)
tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
- tiled = array_ops.reshape(
- tiled, array_ops.concat(([shape_t[0] * multiplier], shape_t[1:]), 0))
+ tiled = array_ops.reshape(tiled,
+ array_ops.concat(
+ ([shape_t[0] * multiplier], shape_t[1:]), 0))
tiled.set_shape(
- tensor_shape.TensorShape(
- [tiled_static_batch_size]).concatenate(t.shape[1:]))
+ tensor_shape.TensorShape([tiled_static_batch_size]).concatenate(
+ t.shape[1:]))
return tiled
@@ -198,8 +197,8 @@ class BeamSearchDecoder(decoder.Decoder):
"""
if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access
raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
- if (output_layer is not None
- and not isinstance(output_layer, layers_base.Layer)):
+ if (output_layer is not None and
+ not isinstance(output_layer, layers_base.Layer)):
raise TypeError(
"output_layer must be a Layer, received: %s" % type(output_layer))
self._cell = cell
@@ -224,13 +223,17 @@ class BeamSearchDecoder(decoder.Decoder):
self._beam_width = beam_width
self._length_penalty_weight = length_penalty_weight
self._initial_cell_state = nest.map_structure(
- self._maybe_split_batch_beams,
- initial_state, self._cell.state_size)
+ self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
- self._finished = array_ops.zeros(
- [self._batch_size, self._beam_width], dtype=dtypes.bool)
+
+ self._finished = array_ops.one_hot(
+ array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+ depth=self._beam_width,
+ on_value=False,
+ off_value=True,
+ dtype=dtypes.bool)
@property
def batch_size(self):
@@ -248,8 +251,7 @@ class BeamSearchDecoder(decoder.Decoder):
# dimensions to get the output size of the rnn with the layer
# applied to the top.
output_shape_with_unknown_batch = nest.map_structure(
- lambda s: tensor_shape.TensorShape([None]).concatenate(s),
- size)
+ lambda s: tensor_shape.TensorShape([None]).concatenate(s), size)
layer_output_shape = self._output_layer.compute_output_shape(
output_shape_with_unknown_batch)
return nest.map_structure(lambda s: s[1:], layer_output_shape)
@@ -298,11 +300,16 @@ class BeamSearchDecoder(decoder.Decoder):
"""
finished, start_inputs = self._finished, self._start_inputs
+ log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz)
+ array_ops.zeros([self._batch_size], dtype=dtypes.int32),
+ depth=self._beam_width,
+ on_value=0.0,
+ off_value=-np.Inf,
+ dtype=nest.flatten(self._initial_cell_state)[0].dtype)
+
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
- log_probs=array_ops.zeros(
- [self._batch_size, self._beam_width],
- dtype=nest.flatten(self._initial_cell_state)[0].dtype),
+ log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
[self._batch_size, self._beam_width], dtype=dtypes.int64))
@@ -359,11 +366,12 @@ class BeamSearchDecoder(decoder.Decoder):
t_shape = array_ops.shape(t)
static_batch_size = tensor_util.constant_value(self._batch_size)
batch_size_beam_width = (
- None if static_batch_size is None
- else static_batch_size * self._beam_width)
+ None
+ if static_batch_size is None else static_batch_size * self._beam_width)
reshaped_t = array_ops.reshape(
- t, array_ops.concat(
- ([self._batch_size * self._beam_width], t_shape[2:]), 0))
+ t,
+ array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]),
+ 0))
reshaped_t.set_shape(
(tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s)))
return reshaped_t
@@ -392,8 +400,9 @@ class BeamSearchDecoder(decoder.Decoder):
s = tensor_shape.TensorShape(s)
t_shape = array_ops.shape(t)
reshaped_t = array_ops.reshape(
- t, array_ops.concat(
- ([self._batch_size, self._beam_width], t_shape[1:]), 0))
+ t,
+ array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]),
+ 0))
static_batch_size = tensor_util.constant_value(self._batch_size)
expected_reshaped_shape = tensor_shape.TensorShape(
[static_batch_size, self._beam_width]).concatenate(s)
@@ -403,8 +412,8 @@ class BeamSearchDecoder(decoder.Decoder):
"We expected it to have shape "
"(batch_size, beam_width, depth) == %s. Perhaps you "
"forgot to create a zero_state with "
- "batch_size=encoder_batch_size * beam_width?"
- % (reshaped_t.shape, expected_reshaped_shape))
+ "batch_size=encoder_batch_size * beam_width?" %
+ (reshaped_t.shape, expected_reshaped_shape))
reshaped_t.set_shape(expected_reshaped_shape)
return reshaped_t
@@ -476,15 +485,13 @@ class BeamSearchDecoder(decoder.Decoder):
cell_state = state.cell_state
inputs = nest.map_structure(
lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
- cell_state = nest.map_structure(
- self._maybe_merge_batch_beams,
- cell_state, self._cell.state_size)
+ cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state,
+ self._cell.state_size)
cell_outputs, next_cell_state = self._cell(inputs, cell_state)
cell_outputs = nest.map_structure(
lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
next_cell_state = nest.map_structure(
- self._maybe_split_batch_beams,
- next_cell_state, self._cell.state_size)
+ self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)
if self._output_layer is not None:
cell_outputs = self._output_layer(cell_outputs)
@@ -547,7 +554,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
lengths_to_add = array_ops.one_hot(
indices=array_ops.fill([batch_size, beam_width], end_token),
depth=vocab_size,
- on_value=np.int64(0), off_value=np.int64(1),
+ on_value=np.int64(0),
+ off_value=np.int64(1),
dtype=dtypes.int64)
add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
lengths_to_add *= array_ops.expand_dims(add_mask, 2)
@@ -563,18 +571,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
time = ops.convert_to_tensor(time, name="time")
# During the first time step we only consider the initial beam
scores_shape = array_ops.shape(scores)
- scores_flat = control_flow_ops.cond(
- time > 0,
- lambda: array_ops.reshape(scores, [batch_size, -1]),
- lambda: scores[:, 0])
- num_available_beam = control_flow_ops.cond(
- time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]),
- lambda: math_ops.reduce_prod(scores_shape[2:]))
+ scores_flat = array_ops.reshape(scores, [batch_size, -1])
# Pick the next beams according to the specified successors function
- next_beam_size = math_ops.minimum(
- ops.convert_to_tensor(beam_width, dtype=dtypes.int32, name="beam_width"),
- num_available_beam)
+ next_beam_size = ops.convert_to_tensor(
+ beam_width, dtype=dtypes.int32, name="beam_width")
next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)
next_beam_scores.set_shape([static_batch_size, beam_width])
@@ -593,11 +594,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
# name="next_beam_word_ids")
# would be a lot cleaner but for reasons unclear, that hides the results of
# the op which prevents capturing it with tfdbg debug ops.
- raw_next_word_ids = math_ops.mod(word_indices, vocab_size,
- name="next_beam_word_ids")
+ raw_next_word_ids = math_ops.mod(
+ word_indices, vocab_size, name="next_beam_word_ids")
next_word_ids = math_ops.to_int32(raw_next_word_ids)
- next_beam_ids = math_ops.to_int32(word_indices / vocab_size,
- name="next_beam_parent_ids")
+ next_beam_ids = math_ops.to_int32(
+ word_indices / vocab_size, name="next_beam_parent_ids")
# Append new ids to current predictions
previously_finished = _tensor_gather_helper(
@@ -606,9 +607,10 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
batch_size=batch_size,
range_size=beam_width,
gather_shape=[-1])
- next_finished = math_ops.logical_or(previously_finished,
- math_ops.equal(next_word_ids, end_token),
- name="next_beam_finished")
+ next_finished = math_ops.logical_or(
+ previously_finished,
+ math_ops.equal(next_word_ids, end_token),
+ name="next_beam_finished")
# Calculate the length of the next predictions.
# 1. Finished beams remain unchanged.
@@ -769,8 +771,12 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
return gather_from
-def _tensor_gather_helper(gather_indices, gather_from, batch_size,
- range_size, gather_shape, name=None):
+def _tensor_gather_helper(gather_indices,
+ gather_from,
+ batch_size,
+ range_size,
+ gather_shape,
+ name=None):
"""Helper for gathering the right indices from the tensor.
This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
@@ -801,9 +807,9 @@ def _tensor_gather_helper(gather_indices, gather_from, batch_size,
array_ops.reshape(gather_from, gather_shape), gather_indices)
final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
static_batch_size = tensor_util.constant_value(batch_size)
- final_static_shape = (tensor_shape.TensorShape([static_batch_size])
- .concatenate(
- gather_from.shape[1:1 + len(gather_shape)]))
+ final_static_shape = (
+ tensor_shape.TensorShape([static_batch_size]).concatenate(
+ gather_from.shape[1:1 + len(gather_shape)]))
output = array_ops.reshape(output, final_shape, name="output")
output.set_shape(final_static_shape)
return output
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index ef3722ee41..6d8f786223 100644
--- a/tensorflow/contrib/seq2seq/python/ops/helper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -73,6 +73,14 @@ class Helper(object):
raise NotImplementedError("batch_size has not been implemented")
@abc.abstractproperty
+ def input_shape(self):
+ """Shape of each input element in batch.
+
+ Returns a `TensorShape`.
+ """
+ raise NotImplementedError("input_shape has not been implemented")
+
+ @abc.abstractproperty
def sample_ids_shape(self):
"""Shape of tensor returned by `sample`, excluding the batch dimension.
@@ -127,6 +135,7 @@ class CustomHelper(Helper):
self._sample_fn = sample_fn
self._next_inputs_fn = next_inputs_fn
self._batch_size = None
+ self._input_shape = None
self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or [])
self._sample_ids_dtype = sample_ids_dtype or dtypes.int32
@@ -149,6 +158,8 @@ class CustomHelper(Helper):
(finished, next_inputs) = self._initialize_fn()
if self._batch_size is None:
self._batch_size = array_ops.size(finished)
+ if self._input_shape is None:
+ self._input_shape = next_inputs.shape[1:]
return (finished, next_inputs)
def sample(self, time, outputs, state, name=None):
@@ -184,6 +195,7 @@ class TrainingHelper(Helper):
"""
with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
inputs = ops.convert_to_tensor(inputs, name="inputs")
+ self._inputs = inputs
if not time_major:
inputs = nest.map_structure(_transpose_batch_time, inputs)
@@ -199,12 +211,17 @@ class TrainingHelper(Helper):
lambda inp: array_ops.zeros_like(inp[0, :]), inputs)
self._batch_size = array_ops.size(sequence_length)
+ self._input_shape = inputs.shape[2:]
@property
def batch_size(self):
return self._batch_size
@property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
def sample_ids_shape(self):
return tensor_shape.TensorShape([])
@@ -212,6 +229,14 @@ class TrainingHelper(Helper):
def sample_ids_dtype(self):
return dtypes.int32
+ @property
+ def inputs(self):
+ return self._inputs
+
+ @property
+ def sequence_length(self):
+ return self._sequence_length
+
def initialize(self, name=None):
with ops.name_scope(name, "TrainingHelperInitialize"):
finished = math_ops.equal(0, self._sequence_length)
@@ -516,12 +541,17 @@ class GreedyEmbeddingHelper(Helper):
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._start_inputs = self._embedding_fn(self._start_tokens)
+ self._input_shape = self._start_inputs.shape[1:]
@property
def batch_size(self):
return self._batch_size
@property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
def sample_ids_shape(self):
return tensor_shape.TensorShape([])
@@ -632,6 +662,8 @@ class InferenceHelper(Helper):
self._sample_dtype = sample_dtype
self._next_inputs_fn = next_inputs_fn
self._batch_size = array_ops.shape(start_inputs)[0]
+ self._input_shape = start_inputs.shape[1:]
+
self._start_inputs = ops.convert_to_tensor(
start_inputs, name="start_inputs")
@@ -640,6 +672,10 @@ class InferenceHelper(Helper):
return self._batch_size
@property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
def sample_ids_shape(self):
return self._sample_shape
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h
index e24efa0de1..4628b6ab1b 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.h
+++ b/tensorflow/contrib/session_bundle/bundle_shim.h
@@ -15,8 +15,8 @@ limitations under the License.
// Shim for systems that need to load both SessionBundle and
// SavedModelBundle interchangeably during migration to SavedModel.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_
+#ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_
+#define TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_
#include <memory>
@@ -67,4 +67,4 @@ Status LoadSessionBundleOrSavedModelBundle(
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_
+#endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_BUNDLE_SHIM_H_
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.py b/tensorflow/contrib/session_bundle/bundle_shim.py
index 062c9cc680..3149875e41 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.py
+++ b/tensorflow/contrib/session_bundle/bundle_shim.py
@@ -82,7 +82,7 @@ def _convert_default_signature_to_signature_def(signatures):
"""
default_signature = signatures.default_signature
signature_def = meta_graph_pb2.SignatureDef()
- if default_signature.WhichOneof("type") == "regression_signature":
+ if default_signature.WhichOneof("type") == legacy_constants.REGRESSION_SIGNATURE:
regression_signature = default_signature.regression_signature
signature_def.method_name = signature_constants.REGRESS_METHOD_NAME
_add_input_to_signature_def(regression_signature.input.tensor_name,
@@ -91,7 +91,7 @@ def _convert_default_signature_to_signature_def(signatures):
_add_output_to_signature_def(regression_signature.output.tensor_name,
signature_constants.REGRESS_OUTPUTS,
signature_def)
- elif default_signature.WhichOneof("type") == "classification_signature":
+ elif default_signature.WhichOneof("type") == legacy_constants.CLASSIFICATION_SIGNATURE:
classification_signature = default_signature.classification_signature
signature_def.method_name = signature_constants.CLASSIFY_METHOD_NAME
_add_input_to_signature_def(classification_signature.input.tensor_name,
@@ -132,8 +132,8 @@ def _convert_named_signatures_to_signature_def(signatures):
signature_constants.PREDICT_OUTPUTS]
# TODO(pdudnik): what if there are other signatures? Mimic cr/140900781 once
# it is submitted.
- if (input_signature.WhichOneof("type") != "generic_signature" or
- output_signature.WhichOneof("type") != "generic_signature"):
+ if (input_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE or
+ output_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE):
raise RuntimeError("Named input and output signatures can only be "
"up-converted if they are generic signature. "
"Input signature type is %s, output signature type is "
diff --git a/tensorflow/contrib/session_bundle/constants.py b/tensorflow/contrib/session_bundle/constants.py
index 6ced73241a..e833baee79 100644
--- a/tensorflow/contrib/session_bundle/constants.py
+++ b/tensorflow/contrib/session_bundle/constants.py
@@ -32,3 +32,6 @@ INIT_OP_KEY = "serving_init_op"
SIGNATURES_KEY = "serving_signatures"
ASSETS_KEY = "serving_assets"
GRAPH_KEY = "serving_graph"
+REGRESSION_SIGNATURE = "regression_signature"
+CLASSIFICATION_SIGNATURE = "classification_signature"
+GENERIC_SIGNATURE = "generic_signature"
diff --git a/tensorflow/contrib/session_bundle/session_bundle.h b/tensorflow/contrib/session_bundle/session_bundle.h
index 2ff258411d..b2be46efa6 100644
--- a/tensorflow/contrib/session_bundle/session_bundle.h
+++ b/tensorflow/contrib/session_bundle/session_bundle.h
@@ -15,8 +15,8 @@ limitations under the License.
// Low-level functionality for setting up a inference Session.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
+#ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
+#define TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
#include <memory>
@@ -82,4 +82,4 @@ bool IsPossibleExportDirectory(const StringPiece export_dir);
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
+#endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
diff --git a/tensorflow/contrib/session_bundle/signature.h b/tensorflow/contrib/session_bundle/signature.h
index 0049bea008..4ef1277cec 100644
--- a/tensorflow/contrib/session_bundle/signature.h
+++ b/tensorflow/contrib/session_bundle/signature.h
@@ -15,8 +15,8 @@ limitations under the License.
// Helpers for working with TensorFlow exports and their signatures.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
+#ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
+#define TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
#include <string>
#include <utility>
@@ -121,4 +121,4 @@ Status BindGenericNames(const GenericSignature& signature,
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
+#endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
diff --git a/tensorflow/contrib/session_bundle/test_util.h b/tensorflow/contrib/session_bundle/test_util.h
index dd0fc8d1c0..f0d41ce5a4 100644
--- a/tensorflow/contrib/session_bundle/test_util.h
+++ b/tensorflow/contrib/session_bundle/test_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
+#define TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
#include <string>
@@ -35,4 +35,4 @@ string TestSrcDirPath(const string& relative_path);
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index 870f504d10..f5a9299d26 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -236,7 +236,7 @@ class SingleEvaluationTest(test.TestCase):
def _prepareCheckpoint(self, checkpoint_path):
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
- saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
+ saver = saver_lib.Saver()
with self.test_session() as sess:
sess.run(init_op)
saver.save(sess, checkpoint_path)
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
index 930df2414b..7b609ae96b 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
@@ -45,32 +45,55 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_):
low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
# Make a selfadjoint, positive definite.
a_np = np.dot(a_np.T, a_np)
+ # jacobi preconditioner
+ jacobi_np = np.zeros_like(a_np)
+ jacobi_np[range(a_np.shape[0]), range(a_np.shape[1])] = (1.0 /
+ a_np.diagonal())
rhs_np = np.random.uniform(
low=-1.0, high=1.0, size=shape_[0]).astype(dtype_)
+ x_np = np.zeros_like(rhs_np)
tol = 1e-6 if dtype_ == np.float64 else 1e-3
max_iter = 20
with self.test_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np)
rhs = constant_op.constant(rhs_np)
+ x = constant_op.constant(x_np)
+ jacobi = constant_op.constant(jacobi_np)
else:
a = array_ops.placeholder(dtype_)
rhs = array_ops.placeholder(dtype_)
+ x = array_ops.placeholder(dtype_)
+ jacobi = array_ops.placeholder(dtype_)
operator = util.create_operator(a)
- cg_graph = linear_equations.conjugate_gradient(
- operator, rhs, tol=tol, max_iter=max_iter)
- if use_static_shape_:
- cg_val = sess.run(cg_graph)
- else:
- cg_val = sess.run(cg_graph, feed_dict={a: a_np, rhs: rhs_np})
- norm_r0 = np.linalg.norm(rhs_np)
- norm_r = np.sqrt(cg_val.gamma)
- self.assertLessEqual(norm_r, tol * norm_r0)
- # Validate that we get an equally small residual norm with numpy
- # using the computed solution.
- r_np = rhs_np - np.dot(a_np, cg_val.x)
- norm_r_np = np.linalg.norm(r_np)
- self.assertLessEqual(norm_r_np, tol * norm_r0)
+ preconditioners = [None, util.identity_operator(a),
+ util.create_operator(jacobi)]
+ cg_results = []
+ for preconditioner in preconditioners:
+ cg_graph = linear_equations.conjugate_gradient(
+ operator, rhs, preconditioner=preconditioner,
+ x=x, tol=tol, max_iter=max_iter)
+ if use_static_shape_:
+ cg_val = sess.run(cg_graph)
+ else:
+ cg_val = sess.run(cg_graph, feed_dict={a: a_np, rhs: rhs_np, x: x_np,
+ jacobi: jacobi_np})
+ norm_r0 = np.linalg.norm(rhs_np)
+ norm_r = np.linalg.norm(cg_val.r)
+ self.assertLessEqual(norm_r, tol * norm_r0)
+ # Validate that we get an equally small residual norm with numpy
+ # using the computed solution.
+ r_np = rhs_np - np.dot(a_np, cg_val.x)
+ norm_r_np = np.linalg.norm(r_np)
+ self.assertLessEqual(norm_r_np, tol * norm_r0)
+ cg_results.append(cg_val)
+ # Validate that we get same results using identity_preconditioner
+ # and None
+ self.assertEqual(cg_results[0].i, cg_results[1].i)
+ self.assertAlmostEqual(cg_results[0].gamma, cg_results[1].gamma)
+ self.assertAllClose(cg_results[0].r, cg_results[1].r, rtol=tol)
+ self.assertAllClose(cg_results[0].x, cg_results[1].x, rtol=tol)
+ self.assertAllClose(cg_results[0].p, cg_results[1].p, rtol=tol)
return [test_conjugate_gradient]
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
index 1566984b27..12e94369cb 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
@@ -63,6 +63,41 @@ class UtilTest(test.TestCase):
def testCreateOperatorUnknownShape(self):
self._testCreateOperator(False)
+ def _testIdentityOperator(self, use_static_shape_):
+ for dtype in np.float32, np.float64:
+ a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
+ x_np = np.array([[2.], [-3.]], dtype=dtype)
+ y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
+ with self.test_session() as sess:
+ if use_static_shape_:
+ a = constant_op.constant(a_np, dtype=dtype)
+ x = constant_op.constant(x_np, dtype=dtype)
+ y = constant_op.constant(y_np, dtype=dtype)
+ else:
+ a = array_ops.placeholder(dtype)
+ x = array_ops.placeholder(dtype)
+ y = array_ops.placeholder(dtype)
+ id_op = util.identity_operator(a)
+ ax = id_op.apply(x)
+ aty = id_op.apply_adjoint(y)
+ op_shape = ops.convert_to_tensor(id_op.shape)
+ if use_static_shape_:
+ op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty])
+ else:
+ op_shape_val, ax_val, aty_val = sess.run(
+ [op_shape, ax, aty], feed_dict={a: a_np,
+ x: x_np,
+ y: y_np})
+ self.assertAllEqual(op_shape_val, [3, 2])
+ self.assertAllClose(ax_val, x_np)
+ self.assertAllClose(aty_val, y_np)
+
+ def testIdentityOperator(self):
+ self._testIdentityOperator(True)
+
+ def testIdentityOperatorUnknownShape(self):
+ self._testIdentityOperator(False)
+
def testL2Norm(self):
with self.test_session():
x_np = np.array([[2], [-3.], [5.]])
diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py
index 8cba56eba6..4dfaa97ac9 100644
--- a/tensorflow/contrib/solvers/python/ops/linear_equations.py
+++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py
@@ -27,10 +27,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import linalg_ops
def conjugate_gradient(operator,
rhs,
+ preconditioner=None,
+ x=None,
tol=1e-4,
max_iter=20,
name="conjugate_gradient"):
@@ -55,6 +58,15 @@ def conjugate_gradient(operator,
vector with the result of applying the operator to `x`, i.e. if
`operator` represents matrix `A`, `apply` should return `A * x`.
rhs: A rank-1 `Tensor` of shape `[N]` containing the right-hand size vector.
+ preconditioner: An object representing a linear operator, see `operator`
+ for detail. The preconditioner should approximate the inverse of `A`.
+ An efficient preconditioner could dramatically improve the rate of
+ convergence. If `preconditioner` represents matrix `M`(`M` approximates
+ `A^{-1}`), the algorithm uses `preconditioner.apply(x)` to estimate
+ `A^{-1}x`. For this to be useful, the cost of applying `M` should be
+ much lower than computing `A^{-1}` directly.
+ x: A rank-1 `Tensor` of shape `[N]` containing the initial guess for the
+ solution.
tol: A float scalar convergence tolerance.
max_iter: An integer giving the maximum number of iterations.
name: A name scope for the operation.
@@ -65,35 +77,51 @@ def conjugate_gradient(operator,
- x: A rank-1 `Tensor` of shape `[N]` containing the computed solution.
- r: A rank-1 `Tensor` of shape `[M]` containing the residual vector.
- p: A rank-1 `Tensor` of shape `[N]`. `A`-conjugate basis vector.
- - gamma: \\(||r||_2^2\\)
+ - gamma: \\(r \dot M \dot r\\), equivalent to \\(||r||_2^2\\) when
+ `preconditioner=None`.
"""
# ephemeral class holding CG state.
cg_state = collections.namedtuple("CGState", ["i", "x", "r", "p", "gamma"])
def stopping_criterion(i, state):
- return math_ops.logical_and(i < max_iter, state.gamma > tol)
+ return math_ops.logical_and(i < max_iter,
+ linalg_ops.norm(state.r) > tol)
- # TODO(rmlarsen): add preconditioning
def cg_step(i, state):
z = operator.apply(state.p)
alpha = state.gamma / util.dot(state.p, z)
x = state.x + alpha * state.p
r = state.r - alpha * z
- gamma = util.l2norm_squared(r)
- beta = gamma / state.gamma
- p = r + beta * state.p
+ if preconditioner is None:
+ gamma = util.dot(r, r)
+ beta = gamma / state.gamma
+ p = r + beta * state.p
+ else:
+ q = preconditioner.apply(r)
+ gamma = util.dot(r, q)
+ beta = gamma / state.gamma
+ p = q + beta * state.p
return i + 1, cg_state(i + 1, x, r, p, gamma)
with ops.name_scope(name):
n = operator.shape[1:]
rhs = array_ops.expand_dims(rhs, -1)
- gamma0 = util.l2norm_squared(rhs)
- tol = tol * tol * gamma0
- x = array_ops.expand_dims(
- array_ops.zeros(
- n, dtype=rhs.dtype.base_dtype), -1)
+ if x is None:
+ x = array_ops.expand_dims(
+ array_ops.zeros(
+ n, dtype=rhs.dtype.base_dtype), -1)
+ r0 = rhs
+ else:
+ x = array_ops.expand_dims(x, -1)
+ r0 = rhs - operator.apply(x)
+ if preconditioner is None:
+ p0 = r0
+ else:
+ p0 = preconditioner.apply(r0)
+ gamma0 = util.dot(r0, p0)
+ tol = tol * linalg_ops.norm(r0)
i = constant_op.constant(0, dtype=dtypes.int32)
- state = cg_state(i=i, x=x, r=rhs, p=rhs, gamma=gamma0)
+ state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0)
_, state = control_flow_ops.while_loop(stopping_criterion, cg_step,
[i, state])
return cg_state(
diff --git a/tensorflow/contrib/solvers/python/ops/util.py b/tensorflow/contrib/solvers/python/ops/util.py
index 777e0c185d..96947e8eea 100644
--- a/tensorflow/contrib/solvers/python/ops/util.py
+++ b/tensorflow/contrib/solvers/python/ops/util.py
@@ -45,6 +45,23 @@ def create_operator(matrix):
apply_adjoint=lambda v: math_ops.matmul(matrix, v, adjoint_a=True))
+def identity_operator(matrix):
+ """Creates a linear operator from a rank-2 identity tensor."""
+
+ linear_operator = collections.namedtuple(
+ "LinearOperator", ["shape", "dtype", "apply", "apply_adjoint"])
+ shape = matrix.get_shape()
+ if shape.is_fully_defined():
+ shape = shape.as_list()
+ else:
+ shape = array_ops.shape(matrix)
+ return linear_operator(
+ shape=shape,
+ dtype=matrix.dtype,
+ apply=lambda v: v,
+ apply_adjoint=lambda v: v)
+
+
# TODO(rmlarsen): Measure if we should just call matmul.
def dot(x, y):
return math_ops.reduce_sum(math_ops.conj(x) * y)
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
index c8b4e472c9..360e7dbe75 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
@@ -105,8 +105,8 @@ class SparsemaxLossTest(test.TestCase):
tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
np_loss = self._np_sparsemax_loss(z, q).astype(dtype)
- self.assertAllCloseAccordingToType(np_loss, tf_loss_out,
- half_atol=1e-2, half_rtol=5e-3)
+ self.assertAllCloseAccordingToType(
+ np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3)
self.assertShapeEqual(np_loss, tf_loss_op)
def _test_constant_add(self, dtype, random, use_gpu):
@@ -116,17 +116,17 @@ class SparsemaxLossTest(test.TestCase):
q = np.zeros((test_obs, 10))
q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1
- _, tf_loss_zpc = self._tf_sparsemax_loss(
- z + c, q, dtype, use_gpu
- )
+ _, tf_loss_zpc = self._tf_sparsemax_loss(z + c, q, dtype, use_gpu)
- _, tf_loss_z = self._tf_sparsemax_loss(
- z, q, dtype, use_gpu
- )
+ _, tf_loss_z = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
- self.assertAllCloseAccordingToType(tf_loss_zpc, tf_loss_z,
- float_atol=5e-6, float_rtol=5e-6,
- half_atol=1e-2, half_rtol=1e-2)
+ self.assertAllCloseAccordingToType(
+ tf_loss_zpc,
+ tf_loss_z,
+ float_atol=5e-6,
+ float_rtol=5e-6,
+ half_atol=1e-2,
+ half_rtol=1e-2)
def _test_sparsemax_loss_positive(self, dtype, random, use_gpu):
"""check sparsemax-loss proposition 4"""
@@ -170,10 +170,7 @@ class SparsemaxLossTest(test.TestCase):
with self.test_session(use_gpu=use_gpu):
err = gradient_checker.compute_gradient_error(
- logits, z.shape,
- loss_op, (test_obs, ),
- x_init_value=z, delta=1e-9
- )
+ logits, z.shape, loss_op, (test_obs,), x_init_value=z, delta=1e-9)
self.assertLess(err, 1e-4)
@@ -192,8 +189,8 @@ class SparsemaxLossTest(test.TestCase):
tf_grad = loss_grad_op.eval()
np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype)
- self.assertAllCloseAccordingToType(np_grad, tf_grad,
- half_atol=1e-2, half_rtol=5e-3)
+ self.assertAllCloseAccordingToType(
+ np_grad, tf_grad, half_atol=1e-2, half_rtol=5e-3)
self.assertShapeEqual(np_grad, loss_grad_op)
def _test_dtype(self, dtype):
@@ -220,5 +217,6 @@ class SparsemaxLossTest(test.TestCase):
def testDouble(self):
self._test_dtype('float64')
-if __name__ == "__main__":
+
+if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
index 82d36ee9cb..259e62bd86 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -83,8 +83,8 @@ class SparsemaxTest(test.TestCase):
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
p_sparemax = self._np_sparsemax(z).astype(dtype)
- self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out,
- half_atol=5e-3)
+ self.assertAllCloseAccordingToType(
+ p_sparemax, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
@@ -111,9 +111,8 @@ class SparsemaxTest(test.TestCase):
p_expected = np.zeros((test_obs, 10), dtype=dtype)
p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1
- tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
- (1 / epsilon) * z, dtype, use_gpu
- )
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax((1 / epsilon) * z,
+ dtype, use_gpu)
self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out)
self.assertShapeEqual(p_expected, tf_sparsemax_op)
@@ -123,16 +122,12 @@ class SparsemaxTest(test.TestCase):
z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype)
- _, tf_sparsemax_zpc = self._tf_sparsemax(
- z + c, dtype, use_gpu
- )
+ _, tf_sparsemax_zpc = self._tf_sparsemax(z + c, dtype, use_gpu)
- _, tf_sparsemax_z = self._tf_sparsemax(
- z, dtype, use_gpu
- )
+ _, tf_sparsemax_z = self._tf_sparsemax(z, dtype, use_gpu)
- self.assertAllCloseAccordingToType(tf_sparsemax_zpc, tf_sparsemax_z,
- half_atol=5e-3)
+ self.assertAllCloseAccordingToType(
+ tf_sparsemax_zpc, tf_sparsemax_z, half_atol=5e-3)
def _test_permutation(self, dtype, random, use_gpu):
"""check sparsemax proposition 3"""
@@ -143,12 +138,11 @@ class SparsemaxTest(test.TestCase):
per = random.permutation(10)
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
- z[i, per].reshape(1, -1), dtype, use_gpu
- )
+ z[i, per].reshape(1, -1), dtype, use_gpu)
p_expected = p[i, per].reshape(1, -1)
- self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out,
- half_atol=5e-3)
+ self.assertAllCloseAccordingToType(
+ p_expected, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_expected, tf_sparsemax_op)
def _test_diffrence(self, dtype, random, use_gpu):
@@ -166,18 +160,14 @@ class SparsemaxTest(test.TestCase):
continue
self.assertTrue(
- 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
- "0 <= %.10f <= %.10f" % (
- p[val, j] - p[val, i], z[val, j] - z[val, i] + etol
- )
- )
+ 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
+ '0 <= %.10f <= %.10f' % (p[val, j] - p[val, i],
+ z[val, j] - z[val, i] + etol))
def _test_two_dimentional(self, dtype, random, use_gpu):
"""check two dimentation sparsemax case"""
t = np.linspace(-2, 2, test_obs, dtype=dtype)
- z = np.vstack([
- t, np.zeros(test_obs, dtype=dtype)
- ]).T
+ z = np.vstack([t, np.zeros(test_obs, dtype=dtype)]).T
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
@@ -196,10 +186,7 @@ class SparsemaxTest(test.TestCase):
with self.test_session(use_gpu=use_gpu):
err = gradient_checker.compute_gradient_error(
- logits, z.shape,
- sparsemax_op, z.shape,
- x_init_value=z, delta=1e-9
- )
+ logits, z.shape, sparsemax_op, z.shape, x_init_value=z, delta=1e-9)
self.assertLess(err, 1e-4)
@@ -248,5 +235,6 @@ class SparsemaxTest(test.TestCase):
def testDouble(self):
self._test_dtype('float64')
-if __name__ == "__main__":
+
+if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/tensor_forest/kernels/data_spec.h b/tensorflow/contrib/tensor_forest/kernels/data_spec.h
index 05590d6992..0a3abe56df 100644
--- a/tensorflow/contrib/tensor_forest/kernels/data_spec.h
+++ b/tensorflow/contrib/tensor_forest/kernels/data_spec.h
@@ -15,8 +15,8 @@
// This is a surrogate for using a proto, since it doesn't seem to be possible
// to use protos in a dynamically-loaded/shared-linkage library, which is
// what is used for custom ops in tensorflow/contrib.
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
#include <unordered_map>
#include "tensorflow/core/lib/strings/numbers.h"
@@ -138,4 +138,4 @@ class TensorForestDataSpec {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_DATA_SPEC_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
index 35f9fb7eaf..dad9df4898 100644
--- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
+++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
#include <limits>
@@ -307,4 +307,4 @@ void GetParentWeightedMean(float leaf_sum, const float* leaf_data,
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_CORE_OPS_TREE_UTILS_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h
index 4bd1f06c72..2e7368dc12 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
#include <string>
#include <vector>
@@ -70,4 +70,4 @@ class CandidateGraphRunner {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
index bf88216d66..cced26b903 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
@@ -88,4 +88,4 @@ class DecisionTreeResource : public ResourceBase {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
index 3f03c2d05b..85ce7b825b 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
@@ -104,4 +104,4 @@ struct CandidateEvalatorCollection {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
index dacf033d99..0d6712e9e5 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
#include <vector>
@@ -98,4 +98,4 @@ class FertileStatsResource : public ResourceBase {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h
index 2ae3a79b3d..4ae48179af 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
#include <vector>
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
@@ -78,4 +78,4 @@ class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
index 3e41ab50b9..02c0fc687f 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
#include <unordered_map>
#include <vector>
@@ -316,7 +316,7 @@ class DenseClassificationGrowStats : public ClassificationStats {
void PackToProto(FertileSlot* slot) const override;
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
- LeafStat* right_stats) const;
+ LeafStat* right_stats) const override;
protected:
void ClassificationAddSplitStats() override {
@@ -383,7 +383,7 @@ class SparseClassificationGrowStats : public ClassificationStats {
void PackToProto(FertileSlot* slot) const override;
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
- LeafStat* right_stats) const;
+ LeafStat* right_stats) const override;
protected:
void ClassificationAddSplitStats() override {
@@ -609,4 +609,4 @@ class LeastSquaresRegressionGrowStats : public GrowStats {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
index 14cb19d36f..bf0fb92450 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
@@ -21,8 +21,6 @@ namespace tensorflow {
namespace tensorforest {
namespace {
-const int32 SPARSE_DEFAULT = 0;
-
bool DecideInequalityTest(const decision_trees::InequalityTest& test,
float value) {
float bias = test.threshold().float_value();
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
index e3d4edbf8a..eafad6b591 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
#include <ctime>
#include <unordered_map>
#include "google/protobuf/any.pb.h"
@@ -123,4 +123,4 @@ class TensorDataSet {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h
index 0309ec1de9..44ec09c50e 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -89,4 +89,4 @@ class TensorInputTarget : public StoredInputTarget<SingleDimStorageType> {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
index 946a648f22..cc4ec8dc9e 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
@@ -146,4 +146,4 @@ class LeafModelOperatorFactory {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params.h b/tensorflow/contrib/tensor_forest/kernels/v4/params.h
index 97a9d8d096..b0ed949424 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/params.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/params.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -28,5 +28,4 @@ float ResolveParam(const DepthDependentParam& param, int32 depth);
} // namespace tensorforest
} // namespace tensorflow
-
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
index 6c21c0bd34..ad52f89fad 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
#include <vector>
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
@@ -128,6 +128,4 @@ class AnyCollectionCreator : public CollectionCreator {
} // namespace tensorforest
} // namespace tensorflow
-
-
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h
index 8e002d0414..e6140065bb 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -47,4 +47,4 @@ float WeightedSmoothedGini(float sum, float square, int num_classes);
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h
index b6e543b96f..289c81e9d5 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_
+#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
@@ -71,4 +71,4 @@ class TestableDataSet : public TensorDataSet {
} // namespace tensorforest
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_
+#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index f92b60b03a..eeb308fee8 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -1,11 +1,14 @@
-# -*- python -*-
# Description:
-# provide tensorrt operators and converter package
+# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow
+# and provide TensorRT operators and converter package.
+# APIs are meant to change over time.
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
+exports_files(["LICENSE"])
+
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
"//tensorflow:tensorflow.bzl",
@@ -15,11 +18,32 @@ load(
"tf_py_wrap_cc",
"tf_cc_test",
"tf_kernel_library",
+ "tf_cuda_cc_test",
"tf_custom_op_py_library",
"tf_copts",
)
+load(
+ "@local_config_tensorrt//:build_defs.bzl",
+ "if_tensorrt",
+)
-
+tf_cuda_cc_test(
+ name = "tensorrt_test_cc",
+ size = "small",
+ srcs = ["tensorrt_test.cc"],
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ] + if_tensorrt([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
tf_custom_op_library(
name = "python/ops/_trt_engine_op.so",
@@ -55,7 +79,6 @@ cc_library(
]
)
-
tf_kernel_library(
name = "trt_engine_op_kernel",
srcs = [
@@ -88,7 +111,6 @@ tf_gen_op_libs(
]
)
-
cc_library(
name="trt_logging",
srcs = [
@@ -112,7 +134,6 @@ tf_gen_op_wrapper_py(
],
)
-
tf_custom_op_py_library(
name = "trt_engine_op_loader",
srcs = ["python/ops/trt_engine_op.py"],
@@ -135,8 +156,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":trt_ops_py",
- ":trt_convert_py",
-
+ ":trt_convert_py",
],
)
@@ -209,7 +229,6 @@ tf_custom_op_library(
],
)
-
# Library for the segmenting portion of TensorRT operation creation
cc_library(
name = "segment",
@@ -241,7 +260,6 @@ tf_cc_test(
],
)
-
# Library for the node-level conversion portion of TensorRT operation creation
filegroup(
diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc
new file mode 100644
index 0000000000..e11522ea5b
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "cuda/include/cuda.h"
+#include "cuda/include/cuda_runtime_api.h"
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace {
+
+class Logger : public nvinfer1::ILogger {
+ public:
+ void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
+ switch (severity) {
+ case Severity::kINFO:
+ LOG(INFO) << msg;
+ break;
+ case Severity::kWARNING:
+ LOG(WARNING) << msg;
+ break;
+ case Severity::kINTERNAL_ERROR:
+ case Severity::kERROR:
+ LOG(ERROR) << msg;
+ break;
+ default:
+ break;
+ }
+ }
+};
+
+class ScopedWeights {
+ public:
+ ScopedWeights(float value) : value_(value) {
+ w.type = nvinfer1::DataType::kFLOAT;
+ w.values = &value_;
+ w.count = 1;
+ }
+ const nvinfer1::Weights& get() { return w; }
+
+ private:
+ float value_;
+ nvinfer1::Weights w;
+};
+
+const char* kInputTensor = "input";
+const char* kOutputTensor = "output";
+
+// Creates a network to compute y=2x+3.
+nvinfer1::IHostMemory* CreateNetwork() {
+ Logger logger;
+ nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
+ ScopedWeights weights(2.0);
+ ScopedWeights bias(3.0);
+
+ nvinfer1::INetworkDefinition* network = builder->createNetwork();
+ // Add the input.
+ auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
+ nvinfer1::DimsCHW{1, 1, 1});
+ EXPECT_NE(input, nullptr);
+ // Add the hidden layer.
+ auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
+ EXPECT_NE(layer, nullptr);
+ // Mark the output.
+ auto output = layer->getOutput(0);
+ output->setName(kOutputTensor);
+ network->markOutput(*output);
+ // Build the engine
+ builder->setMaxBatchSize(1);
+ builder->setMaxWorkspaceSize(1 << 10);
+ auto engine = builder->buildCudaEngine(*network);
+ EXPECT_NE(engine, nullptr);
+ // Serialize the engine to create a model, then close everything.
+ nvinfer1::IHostMemory* model = engine->serialize();
+ network->destroy();
+ engine->destroy();
+ builder->destroy();
+ return model;
+}
+
+// Executes the network.
+void Execute(nvinfer1::IExecutionContext& context, const float* input,
+ float* output) {
+ const nvinfer1::ICudaEngine& engine = context.getEngine();
+
+ // We have two bindings: input and output.
+ ASSERT_EQ(engine.getNbBindings(), 2);
+ const int input_index = engine.getBindingIndex(kInputTensor);
+ const int output_index = engine.getBindingIndex(kOutputTensor);
+
+ // Create GPU buffers and a stream
+ void* buffers[2];
+ ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float)));
+ ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float)));
+ cudaStream_t stream;
+ ASSERT_EQ(0, cudaStreamCreate(&stream));
+
+ // Copy the input to the GPU, execute the network, and copy the output back.
+ //
+ // Note that since the host buffer was not created as pinned memory, these
+ // async copies are turned into sync copies. So the following synchronization
+ // could be removed.
+ ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
+ cudaMemcpyHostToDevice, stream));
+ context.enqueue(1, buffers, stream, nullptr);
+ ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
+ cudaMemcpyDeviceToHost, stream));
+ cudaStreamSynchronize(stream);
+
+ // Release the stream and the buffers
+ cudaStreamDestroy(stream);
+ ASSERT_EQ(0, cudaFree(buffers[input_index]));
+ ASSERT_EQ(0, cudaFree(buffers[output_index]));
+}
+
+TEST(TensorrtTest, BasicFunctions) {
+ // Create the network model.
+ nvinfer1::IHostMemory* model = CreateNetwork();
+ // Use the model to create an engine and then an execution context.
+ Logger logger;
+ nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
+ nvinfer1::ICudaEngine* engine =
+ runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
+ model->destroy();
+ nvinfer1::IExecutionContext* context = engine->createExecutionContext();
+
+ // Execute the network.
+ float input = 1234;
+ float output;
+ Execute(*context, &input, &output);
+ EXPECT_EQ(output, input * 2 + 3);
+
+ // Destroy the engine.
+ context->destroy();
+ engine->destroy();
+ runtime->destroy();
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD
index 346c03067d..198da0203a 100644
--- a/tensorflow/contrib/tpu/profiler/BUILD
+++ b/tensorflow/contrib/tpu/profiler/BUILD
@@ -44,13 +44,22 @@ cc_library(
],
)
+cc_library(
+ name = "version",
+ hdrs = ["version.h"],
+ visibility = ["//visibility:public"],
+)
+
tf_cc_binary(
name = "capture_tpu_profile",
- srcs = ["capture_tpu_profile.cc"],
+ srcs = [
+ "capture_tpu_profile.cc",
+ ],
visibility = ["//visibility:public"],
deps = [
":dump_tpu_profile",
":tpu_profiler_proto_cc",
+ ":version",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index b67f2f47a7..7373d0e17c 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h"
#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
+#include "tensorflow/contrib/tpu/profiler/version.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/init_main.h"
@@ -46,12 +47,14 @@ string GetCurrentTimeStampAsString() {
return s;
}
-ProfileResponse Profile(const string& service_addr, int duration_ms) {
+ProfileResponse Profile(const string& service_addr, int duration_ms,
+ const ProfileOptions& opts) {
ProfileRequest request;
request.set_duration_ms(duration_ms);
request.set_max_events(kMaxEvents);
request.add_tools("input_pipeline");
request.add_tools("overview_page");
+ *request.mutable_opts() = opts;
std::cout << "Limiting the number of trace events to " << kMaxEvents
<< std::endl;
::grpc::ClientContext context;
@@ -75,6 +78,7 @@ int main(int argc, char** argv) {
tensorflow::string FLAGS_service_addr;
tensorflow::string FLAGS_logdir;
int FLAGS_duration_ms = 2000;
+ bool FLAGS_include_dataset_ops = true;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("service_addr", &FLAGS_service_addr,
"Address of TPU profiler service e.g. localhost:8466"),
@@ -82,8 +86,13 @@ int main(int argc, char** argv) {
"Path of TensorBoard log directory e.g. /tmp/tb_log"),
tensorflow::Flag("duration_ms", &FLAGS_duration_ms,
"Duration of tracing in ms. Default is 2000ms."),
+ tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops,
+ "Set to false to profile longer TPU device traces."),
};
+ std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION
+ << std::endl;
+
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) {
@@ -93,8 +102,10 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
int duration_ms = FLAGS_duration_ms;
+ tensorflow::ProfileOptions opts;
+ opts.set_include_dataset_ops(FLAGS_include_dataset_ops);
tensorflow::ProfileResponse response =
- tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms);
+ tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms, opts);
// Use the current timestamp as the run name.
tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString();
TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
index 120a38b6c2..b842951eb2 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
@@ -149,8 +149,10 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
// Dumps profile data to <logdir>/plugins/profile/<run>/.
string profile_run_dir = JoinPath(logdir, kProfilePluginDirectory, run);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir));
+
// Ignore computation_graph for now.
- if (response.encoded_trace().empty()) {
+ const bool empty_trace = response.encoded_trace().empty();
+ if (empty_trace) {
*os << "No trace event is collected." << std::endl;
} else {
LOG(INFO) << "Converting trace events to TraceViewer JSON.";
@@ -163,13 +165,12 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir,
response.op_profile(), os));
}
- if (!response.tool_data().empty()) {
+ if (!empty_trace && !response.tool_data().empty()) {
for (const auto& tool_data : response.tool_data()) {
TF_RETURN_IF_ERROR(
DumpToolDataToLogDirectory(profile_run_dir, tool_data, os));
}
}
- TF_RETURN_IF_ERROR(DumpGraphEvents(logdir, run, response, os));
return Status::OK();
}
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
index 65b92aa418..25b958bcfe 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_
+#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_
+#define TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_
#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -35,4 +35,4 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
} // namespace tpu
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_
+#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
index 7970c20a26..846db13329 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
@@ -42,8 +42,9 @@ def main(unused_argv=None):
if not FLAGS.service_addr or not FLAGS.logdir:
sys.exit('service_addr and logdir must be provided.')
executable_path = os.path.join(os.path.dirname(__file__), EXECUTABLE)
+ logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir))
cmd = [executable_path]
- cmd.append('--logdir='+FLAGS.logdir)
+ cmd.append('--logdir='+logdir)
cmd.append('--service_addr='+FLAGS.service_addr)
cmd.append('--duration_ms='+str(FLAGS.duration_ms))
subprocess.call(cmd)
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index 179d29602b..9219663831 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,16 +20,12 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.3.0-a1'
+_VERSION = '1.4.3-a2'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
]
-REQUIRED_PACKAGES = [
- 'tensorflow >= 1.2.0',
-]
-
setup(
name='cloud_tpu_profiler',
version=_VERSION.replace('-', ''),
@@ -45,13 +41,12 @@ setup(
entry_points={
'console_scripts': CONSOLE_SCRIPTS,
},
- install_requires=REQUIRED_PACKAGES,
classifiers=[
# How mature is this project? Common values are
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
- 'Development Status :: 3 - Alpha',
+ 'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Education',
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index 5440bbbfdd..2094294baa 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -61,6 +61,11 @@ message OpMetricsResult {
message OpMetricsDbResult {
// A bunch of OpMetricsResults.
repeated OpMetricsResult metrics_db = 1;
+ // The total host infeed-enqueue duration in picoseconds.
+ optional uint64 total_host_infeed_enq_duration_ps = 2;
+ // The total of the difference between the start times of two
+ // consecutive infeed-enqueues (per host) in picoseconds.
+ optional uint64 total_host_infeed_enq_start_timestamp_ps_diff = 3;
}
// Result proto for StepInfo.
diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
index bf30d2ce09..f3f3302ceb 100644
--- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
+++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
@@ -13,6 +13,14 @@ service TPUProfiler {
}
}
+message ProfileOptions {
+ // We don't collect the dataset ops by default for better trace-viewer
+ // scalability. The caller can mannually set this field to include the ops.
+ bool include_dataset_ops = 1;
+
+ // next-field: 2
+}
+
message ProfileRequest {
// In future, the caller will be able to customize when profiling starts and
// stops. For now, it collects `duration_ms` milliseconds worth of data.
@@ -25,10 +33,13 @@ message ProfileRequest {
// required profiling tools name such as "input_pipeline_analyzer" etc
repeated string tools = 3;
+ // Optional profiling options that control how a TF session will be profiled.
+ ProfileOptions opts = 4;
+
// In future, the caller will indicate which TF session is being profiled, and
// only data relating to that program will be returned. For now, we assume
// all activity during the profiling period is relevant.
- // next-field: 4
+ // next-field: 5
}
message ProfileToolData {
diff --git a/tensorflow/contrib/tpu/profiler/trace_events_to_json.h b/tensorflow/contrib/tpu/profiler/trace_events_to_json.h
index 992eae43d9..3bd76dd01c 100644
--- a/tensorflow/contrib/tpu/profiler/trace_events_to_json.h
+++ b/tensorflow/contrib/tpu/profiler/trace_events_to_json.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_
+#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_
+#define TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_
#include "tensorflow/contrib/tpu/profiler/trace_events.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -29,4 +29,4 @@ string TraceEventsToJson(const Trace &trace);
} // namespace tpu
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_
+#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_TRACE_EVENTS_TO_JSON_H_
diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h
new file mode 100644
index 0000000000..0f645a5492
--- /dev/null
+++ b/tensorflow/contrib/tpu/profiler/version.h
@@ -0,0 +1,21 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
+#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
+
+#define TPU_PROFILER_VERSION "1.4.3"
+
+#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index 33e47f674d..1c970655d0 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import platform
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
if platform.system() != "Windows":
@@ -40,6 +41,63 @@ if platform.system() != "Windows":
del op # Unused
# The gradient of a cross replica sum is also a cross-replica sum.
return gen_tpu_ops.cross_replica_sum(grad)
+
+ # This extra type checking exists to give a more helpful error message in
+ # the common case that uint8 and int64 values are infed. Remove when both
+ # types are supported.
+
+ _SUPPORTED_INFEED_DTYPES = set([
+ dtypes.bool, dtypes.int32, dtypes.bfloat16, dtypes.float32
+ ])
+
+ def infeed_dequeue(dtype, shape, name=None):
+ """A placeholder op for a value that will be fed into the computation.
+
+ Args:
+ dtype: A `tf.DType`. The type of elements in the tensor.
+ shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `dtype`.
+ A tensor that will be provided using the infeed mechanism.
+
+ Raises:
+ TypeError: If 'dtype` is not a supported infeed type.
+ """
+ if dtype not in _SUPPORTED_INFEED_DTYPES:
+ raise TypeError(
+ "{} is not a supported TPU infeed type. Supported types are: "
+ "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
+
+ return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
+
+ # pylint: disable=redefined-outer-name
+ def infeed_dequeue_tuple(dtypes, shapes, name=None):
+ """A placeholder op for values fed into the TPU simultaneously as a tuple.
+
+ Args:
+ dtypes: A list of `tf.DType`s that has length `>= 1`.
+ The element types of each element in `outputs`.
+ shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
+ The shapes of each tensor in `outputs`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A list of `Tensor` objects of type `dtypes`.
+ A list of tensors that will be provided using the infeed mechanism.
+
+ Raises:
+ TypeError: If a type in 'dtypes` is not a supported infeed type.
+ """
+ for dtype in dtypes:
+ if dtype not in _SUPPORTED_INFEED_DTYPES:
+ raise TypeError(
+ "{} is not a supported TPU infeed type. Supported types are: "
+ "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
+ return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
+ # pylint: enable=redefined-outer-name
+
else:
# We have already built the appropriate libraries into the binary via CMake
# if we have built contrib, so we don't need this
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index bb35f4ece6..2ae3a26a85 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
-
"""TPUEstimator class."""
from __future__ import absolute_import
@@ -24,6 +23,7 @@ from contextlib import contextmanager
import copy
import threading
import time
+import traceback
import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin
@@ -60,7 +60,6 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
-
_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
_TPU_ESTIMATOR = 'tpu_estimator'
@@ -86,28 +85,28 @@ def _create_global_step(graph):
initializer=init_ops.zeros_initializer(),
trainable=False,
use_resource=True,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES,
- ops.GraphKeys.GLOBAL_STEP])
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
def _create_or_get_iterations_per_loop():
graph = ops.get_default_graph()
- iter_vars = graph.get_collection(_TPU_ESTIMATOR)
+ collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)
+ iter_vars = graph.get_collection(collection_name)
if len(iter_vars) == 1:
return iter_vars[0]
elif len(iter_vars) > 1:
raise RuntimeError('Multiple iterations_per_loop_var in collection.')
with ops.colocate_with(training_util.get_global_step()):
- with variable_scope.variable_scope(_TPU_ESTIMATOR,
- reuse=variable_scope.AUTO_REUSE):
+ with variable_scope.variable_scope(
+ _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE):
return variable_scope.get_variable(
_ITERATIONS_PER_LOOP_VAR,
initializer=init_ops.zeros_initializer(),
shape=[],
dtype=dtypes.int32,
trainable=False,
- collections=[_TPU_ESTIMATOR],
+ collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
use_resource=True)
@@ -241,9 +240,9 @@ class _TPUContext(object):
return self._eval_batch_size
return None
- global_batch_size = (self._train_batch_size if
- mode == model_fn_lib.ModeKeys.TRAIN
- else self._eval_batch_size)
+ global_batch_size = (
+ self._train_batch_size
+ if mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size)
# On TPU
if self.is_input_sharded_per_core():
return global_batch_size // self.num_cores
@@ -290,8 +289,9 @@ class _TPUContext(object):
# The tpu job is determined by the run_config. Right now, this method is
# required as tpu_config is not part of the RunConfig.
mode = self._assert_mode()
- master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL
- else run_config.master)
+ master = (
+ run_config.evaluation_master
+ if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
if master in _LOCAL_MASTERS:
return None
@@ -318,6 +318,7 @@ class _TPUContext(object):
def tpu_host_placement_function(self):
"""Returns the TPU host place function."""
master = self.master_job
+
def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name
assert _sentinal is None
if core_id is not None and host_id is not None:
@@ -332,19 +333,23 @@ class _TPUContext(object):
if core_id is not None:
host_id = core_id / 8
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
+
return _placement_function
@property
def tpu_device_placement_function(self):
master = self.master_job
job_device = '' if master is None else ('/job:%s' % master)
+
def _placement_function(i):
return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8)
+
return _placement_function
@property
def tpu_ordinal_function(self):
"""Returns the TPU ordinal fn."""
+
def _tpu_ordinal_function(index):
"""Return the TPU ordinal associated with a shard.
@@ -357,6 +362,7 @@ class _TPUContext(object):
The ordinal of the TPU device the shard's infeed should be placed on.
"""
return index % 8
+
return _tpu_ordinal_function
@@ -370,14 +376,16 @@ class _SIGNAL(object):
STOP = -2
-class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
- 'mode',
- 'predictions',
- 'loss',
- 'train_op',
- 'eval_metrics',
- 'export_outputs',
- 'scaffold_fn'])):
+class TPUEstimatorSpec(
+ collections.namedtuple('TPUEstimatorSpec', [
+ 'mode',
+ 'predictions',
+ 'loss',
+ 'train_op',
+ 'eval_metrics',
+ 'export_outputs',
+ 'scaffold_fn'
+ ])):
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
@@ -387,7 +395,7 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
`Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
To be precise, TPU evaluation expects a slightly different signature from the
- ${tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
+ @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from
@@ -415,111 +423,116 @@ class TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
"""Creates a validated `TPUEstimatorSpec` instance."""
if eval_metrics is not None:
_EvalMetrics.validate(eval_metrics)
- return super(TPUEstimatorSpec, cls).__new__(cls,
- mode=mode,
- predictions=predictions,
- loss=loss,
- train_op=train_op,
- eval_metrics=eval_metrics,
- export_outputs=export_outputs,
- scaffold_fn=scaffold_fn)
+ return super(TPUEstimatorSpec, cls).__new__(
+ cls,
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metrics=eval_metrics,
+ export_outputs=export_outputs,
+ scaffold_fn=scaffold_fn)
def as_estimator_spec(self):
"""Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
eval_metric_ops = _EvalMetrics.to_metric_metric_ops_for_cpu(
self.eval_metrics)
scaffold = self.scaffold_fn() if self.scaffold_fn else None
- return model_fn_lib.EstimatorSpec(mode=self.mode,
- predictions=self.predictions,
- loss=self.loss,
- train_op=self.train_op,
- eval_metric_ops=eval_metric_ops,
- export_outputs=self.export_outputs,
- scaffold=scaffold)
+ return model_fn_lib.EstimatorSpec(
+ mode=self.mode,
+ predictions=self.predictions,
+ loss=self.loss,
+ train_op=self.train_op,
+ eval_metric_ops=eval_metric_ops,
+ export_outputs=self.export_outputs,
+ scaffold=scaffold)
+
+
+class _OpQueueContext(object):
+ """Manages work queue and thread for a infeed/outfeed thread."""
+
+ def __init__(self, name, target, args):
+ self._name = name
+ self._queue = Queue.Queue()
+ args = (self,) + args
+ self._thread = threading.Thread(name=name, target=target, args=args)
+ self._thread.daemon = True
+ self._thread.start()
+
+ def stop(self):
+ self._queue.put(_SIGNAL.STOP)
+
+ def send_next_batch_signal(self, iterations):
+ self._queue.put(iterations)
+
+ def read_iteration_counts(self):
+ while True:
+ signal = self._queue.get(block=True)
+ logging.debug('%s read signal %s', self._name, signal)
+ if signal == _SIGNAL.STOP:
+ logging.info('%s received signal, stopping.', self._name)
+ return
+ yield signal
+ def join(self):
+ logging.info('Shutting down %s thread.' % self._name)
+ self.stop()
+ self._thread.join()
-class _InfeedOutfeedThreadBaseController(object):
- """This wraps the infeed/outfeed thread and stops when Estimator finishes."""
- def __init__(self, thd):
- self._signal_queue = Queue.Queue()
- thd.daemon = True
- thd.start()
- self._thd = thd
+class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
+ """A Session hook setting up the TPU initialization, infeed, and outfeed.
- def block_and_get_signal(self):
- return self._signal_queue.get()
+ This hook does two major things:
+ 1. initialize and shutdown TPU system.
+ 2. launch and join the threads for infeed enqueue and (optional) outfeed
+ dequeue.
+ """
- def send_next_batch_signal(self, signal=_SIGNAL.NEXT_BATCH):
- self._signal_queue.put(signal)
+ def __init__(self, ctx, enqueue_ops, dequeue_ops=None):
+ self._master_job = ctx.master_job
+ self._enqueue_ops = enqueue_ops
+ self._dequeue_ops = dequeue_ops
+ self._initial_infeed_sleep_secs = (
+ ctx.config.tpu_config.initial_infeed_sleep_secs)
+ self._session_cancel_timer = None
- def join(self):
- self._signal_queue.put(_SIGNAL.STOP)
- self._thd.join()
+ self._feed_error = None
+ self._finished = False
+ def begin(self):
+ logging.info('TPU job name %s', self._master_job)
+ self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
+ self._init_op = [tpu.initialize_system(job=self._master_job)]
+ self._finalize_op = [tpu.shutdown_system(job=self._master_job)]
-class _OutfeedThreadController(_InfeedOutfeedThreadBaseController):
- """This wraps the outfeed thread and stops when Estimator finishes."""
+ def _log_error(self, session, error):
+ """Log an infeed or outfeed error.
- def __init__(self, session, dequeue_ops):
- super(_OutfeedThreadController, self).__init__(
- threading.Thread(target=self._execute_dequeue_ops,
- args=(session, dequeue_ops)))
+ This logs a short error message immediately, and schedules a timer to
+ emit the full stack trace and error message after a short period of time.
+ If the main session has terminated by the time the timer triggers, we
+ assume the real source of the error was from the main session and avoid
+ emitting a stack trace for the infeed.
- def _execute_dequeue_ops(self, session, dequeue_ops):
- count = 0
- while True:
- signal = self.block_and_get_signal()
- if signal == _SIGNAL.STOP:
- logging.info('Stop outfeed thread.')
- return
+ Args:
+ session: `tf.Session`, session to be terminated
+ error: exception that triggered logging.
+ """
+ logging.warning(
+ '\n\n'
+ 'Error occurred during infeed/outfeed. This may be due to a compile '
+ 'error in the main session. Waiting for a short time for the main '
+ 'session to come back.\n\n%s', error)
- iterations = signal
- for i in range(iterations):
- logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
- session.run(dequeue_ops)
- count += 1
+ self._feed_error = traceback.format_exc()
- def join(self):
- logging.info('Waiting for Outfeed Thread to exit.')
- super(_OutfeedThreadController, self).join()
-
-
-class _InfeedThreadController(_InfeedOutfeedThreadBaseController):
- """This wraps the infeed thread and stops when Estimator finishes."""
-
- def __init__(self, session, enqueue_ops, initial_infeed_sleep_secs):
- super(_InfeedThreadController, self).__init__(
- threading.Thread(
- target=self._input_thread_fn_for_loading,
- args=(session, enqueue_ops, initial_infeed_sleep_secs)))
-
- def _input_thread_fn_for_loading(self, session, enqueue_ops,
- initial_infeed_sleep_secs):
- count = 0
- if initial_infeed_sleep_secs:
- logging.info('Infeed thread sleeping for %d seconds.',
- initial_infeed_sleep_secs)
- time.sleep(initial_infeed_sleep_secs)
- logging.info('Infeed thread starting after sleep')
- try:
- while True:
- signal = self._signal_queue.get()
- if signal == _SIGNAL.STOP:
- logging.info('Stop Infeed input thread.')
- return
-
- if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- # Enqueue batches for next loop.
- session.run(enqueue_ops)
- else:
- iterations = signal
- for i in range(iterations):
- logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
- session.run(enqueue_ops)
- count += 1
+ # If we've already encountered a feed error, don't schedule another
+ # cancellation op.
+ if self._session_cancel_timer:
+ return
- except Exception: # pylint: disable=broad-except
+ def _cancel_session():
# Close the session to avoid the main thread from hanging. If input
# pipeline triggers any error, the infeed thread dies but the main thread
# for TPU computation waits for the infeed enqueue forever. Close the
@@ -534,77 +547,94 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController):
# exception in the main thread, instead of the expected compile error.
# User code that depends on having the proper exception type will
# therefore be confused.
- logging.error(
- 'Failed running infeed, closing session.\n'
- 'You may see an exception from your main session after this. '
- 'Sleep for 2 minutes before close Session from infeed thread to '
- 'allow the main thread returning an error first, if any.',
- exc_info=1
- )
- time.sleep(120)
- logging.error('Closing the failed session.')
- session.close()
-
- def join(self):
- logging.info('Waiting for Infeed Thread to exit.')
- super(_InfeedThreadController, self).join()
-
-
-class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
- """A Session hook setting up the TPU initialization, infeed, and outfeed.
-
- This hook does two major things:
- 1. initialize and shutdown TPU system.
- 2. launch and join the threads for infeed enqueue and (optional) outfeed
- dequeue.
- """
+ time.sleep(5)
+
+ # If the main session is still running, the infeed/outfeed errors are
+ # legitimate, and should be logged.
+ if not self._finished:
+ logging.error('Feed error: %s', self._feed_error)
+ logging.error('Closing session. A RuntimeError should follow.')
+ session.close()
+
+ self._session_cancel_timer = threading.Thread(target=_cancel_session)
+ self._session_cancel_timer.daemon = True
+ self._session_cancel_timer.start()
+
+ def _run_infeed(self, queue_ctx, session):
+ logging.info('Starting infeed thread controller.')
+ if self._initial_infeed_sleep_secs:
+ logging.info('%s thread sleeping for %d seconds.', self._name,
+ self._initial_infeed_sleep_secs)
+ time.sleep(self._initial_infeed_sleep_secs)
+ logging.info('%s thread starting after sleep', self._name)
- def __init__(self, ctx, enqueue_ops, dequeue_ops=None):
- self._master_job = ctx.master_job
- self._enqueue_ops = enqueue_ops
- self._dequeue_ops = dequeue_ops
- self._initial_infeed_sleep_secs = (
- ctx.config.tpu_config.initial_infeed_sleep_secs)
+ try:
+ if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+ for _ in queue_ctx.read_iteration_counts():
+ session.run(self._enqueue_ops)
+ else:
+ for count, steps in enumerate(queue_ctx.read_iteration_counts()):
+ for i in xrange(steps):
+ logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
+ session.run(self._enqueue_ops)
+ logging.debug('Infeed thread finished, shutting down.')
+ except Exception as e: # pylint: disable=broad-except
+ self._log_error(session, e)
- def begin(self):
- logging.info('TPU job name %s', self._master_job)
- self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_op = [tpu.initialize_system(job=self._master_job)]
- self._finalize_op = [tpu.shutdown_system(job=self._master_job)]
+ def _run_outfeed(self, queue_ctx, session):
+ logging.info('Starting outfeed thread controller.')
+ try:
+ for count, steps in enumerate(queue_ctx.read_iteration_counts()):
+ for i in xrange(steps):
+ logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
+ session.run(self._dequeue_ops)
+ except Exception as e: # pylint: disable=broad-except
+ self._log_error(session, e)
def after_create_session(self, session, coord):
logging.info('Init TPU system')
- session.run(self._init_op,
- options=config_pb2.RunOptions(timeout_in_ms=5*60*1000))
+ session.run(
+ self._init_op,
+ options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
logging.info('Start infeed thread controller')
- self._infeed_thd_controller = _InfeedThreadController(
- session, self._enqueue_ops, self._initial_infeed_sleep_secs)
+ self._infeed_controller = _OpQueueContext(
+ name='InfeedController', target=self._run_infeed, args=(session,))
if self._dequeue_ops is not None:
logging.info('Start outfeed thread controller')
- self._outfeed_thd_controller = _OutfeedThreadController(
- session, self._dequeue_ops)
+ self._outfeed_controller = _OpQueueContext(
+ name='OutfeedController', target=self._run_outfeed, args=(session,))
def before_run(self, run_context):
+ if self._feed_error:
+ logging.warning('Feed error occurred, terminating session.')
+ run_context.request_stop()
+ return
+
iterations = run_context.session.run(self._iterations_per_loop_var)
logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
+ self._infeed_controller.send_next_batch_signal(iterations)
- self._infeed_thd_controller.send_next_batch_signal(iterations)
if self._dequeue_ops is not None:
# TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
- logging.info(
- 'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
- self._outfeed_thd_controller.send_next_batch_signal(iterations)
+ logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
+ iterations)
+ self._outfeed_controller.send_next_batch_signal(iterations)
def end(self, session):
+ if self._session_cancel_timer:
+ logging.warning('Feed error occurred; waiting for message.')
+ self._session_cancel_timer.join()
+
+ self._finished = True
logging.info('Stop infeed thread controller')
- self._infeed_thd_controller.join()
+ self._infeed_controller.join()
if self._dequeue_ops is not None:
logging.info('Stop output thread controller')
- self._outfeed_thd_controller.join()
+ self._outfeed_controller.join()
logging.info('Shutdown TPU system.')
session.run(self._finalize_op)
@@ -675,8 +705,8 @@ class _TPUStopAtStepHook(session_run_hook.SessionRunHook):
run_context.request_stop()
else:
iterations = self._next_iterations(global_step, self._last_step)
- self._iterations_per_loop_var.load(iterations,
- session=run_context.session)
+ self._iterations_per_loop_var.load(
+ iterations, session=run_context.session)
class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
@@ -697,8 +727,8 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
self._iterations_per_loop_var.load(self._num_steps, session=session)
-def generate_per_core_enqueue_ops_fn_for_host(
- ctx, input_fn, inputs_structure_recorder):
+def generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn,
+ inputs_structure_recorder):
"""Generates infeed enqueue ops for per-core input_fn on a single host."""
captured_infeed_queue = _CapturedObject()
@@ -728,9 +758,9 @@ def generate_per_core_enqueue_ops_fn_for_host(
per_host_sharded_inputs)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
- per_host_sharded_inputs,
- tpu_ordinal_function=ctx.tpu_ordinal_function)
+ per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
return per_host_enqueue_ops
+
return enqueue_ops_fn, captured_infeed_queue
@@ -747,8 +777,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
features, labels = inputs
else:
features, labels = inputs, None
- inputs_structure_recorder.validate_and_record_structure(
- features, labels)
+ inputs_structure_recorder.validate_and_record_structure(features, labels)
unsharded_tensor_list = (
inputs_structure_recorder.flatten_features_and_labels(
features, labels))
@@ -762,9 +791,9 @@ def generate_per_host_enqueue_ops_fn_for_host(
per_host_enqueue_ops = (
infeed_queue.split_inputs_and_generate_enqueue_ops(
- unsharded_tensor_list,
- placement_function=lambda x: device))
+ unsharded_tensor_list, placement_function=lambda x: device))
return per_host_enqueue_ops
+
return enqueue_ops_fn, captured_infeed_queue
@@ -814,6 +843,7 @@ class _InputPipeline(object):
def validate_and_record_structure(self, features, labels):
"""Validates and records the structure of features` and `labels`."""
+
def _extract_key_names(tensor_or_dict):
if tensor_or_dict is None:
return []
@@ -841,8 +871,8 @@ class _InputPipeline(object):
flattened_inputs = []
if self._feature_names:
# We need a fixed ordering for enqueueing and dequeueing.
- flattened_inputs.extend([features[name]
- for name in self._feature_names])
+ flattened_inputs.extend(
+ [features[name] for name in self._feature_names])
else:
flattened_inputs.append(features)
@@ -869,11 +899,11 @@ class _InputPipeline(object):
ValueError: If the number of expected tensors from `flattened_inputs`
mismatches the recorded structure.
"""
- expected_num_features = (len(self._feature_names) if self._feature_names
- else 1)
+ expected_num_features = (
+ len(self._feature_names) if self._feature_names else 1)
if self._has_labels:
- expected_num_labels = (len(self._label_names) if self._label_names
- else 1)
+ expected_num_labels = (
+ len(self._label_names) if self._label_names else 1)
else:
expected_num_labels = 0
@@ -894,8 +924,8 @@ class _InputPipeline(object):
if expected_num_labels == 0:
unflattened_label = None
elif self._label_names:
- unflattened_label = dict(zip(self._label_names,
- flattened_inputs[expected_num_features:]))
+ unflattened_label = dict(
+ zip(self._label_names, flattened_inputs[expected_num_features:]))
else:
# Single tensor case.
unflattened_label = flattened_inputs[expected_num_features]
@@ -960,8 +990,9 @@ class _InputPipeline(object):
self._ctx, self._input_fn, self._inputs_structure_recorder))
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- enqueue_ops.append(_wrap_computation_in_while_loop(
- device=host_device, op_fn=enqueue_ops_fn))
+ enqueue_ops.append(
+ _wrap_computation_in_while_loop(
+ device=host_device, op_fn=enqueue_ops_fn))
else:
enqueue_ops.append(enqueue_ops_fn())
# Infeed_queue_getter must be called after enqueue_ops_fn is called.
@@ -978,8 +1009,9 @@ class _InputPipeline(object):
self._batch_axis, host_device))
if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
- enqueue_ops.append(_wrap_computation_in_while_loop(
- device=host_device, op_fn=enqueue_ops_fn))
+ enqueue_ops.append(
+ _wrap_computation_in_while_loop(
+ device=host_device, op_fn=enqueue_ops_fn))
else:
enqueue_ops.append(enqueue_ops_fn())
infeed_queues.append(captured_infeed_queue.get())
@@ -1065,6 +1097,7 @@ class _ModelFnWrapper(object):
with ops.control_dependencies([train_op]):
return array_ops.identity(loss)
+
return train_step, captured_scaffold_fn
def convert_to_single_tpu_eval_step(self, dequeue_fn):
@@ -1113,6 +1146,7 @@ class _ModelFnWrapper(object):
with ops.control_dependencies([outfeed_ops]):
return math_ops.add(total_loss, loss)
+
return eval_step, eval_metrics, captured_scaffold_fn
def _call_model_fn(self, features, labels):
@@ -1137,10 +1171,9 @@ class _ModelFnWrapper(object):
kwargs['params'] = params
if 'params' not in model_fn_args:
- raise ValueError(
- 'model_fn ({}) does not include params argument, '
- 'required by TPUEstimator to pass batch size as '
- 'params[\'batch_size\']'.format(self._model_fn))
+ raise ValueError('model_fn ({}) does not include params argument, '
+ 'required by TPUEstimator to pass batch size as '
+ 'params[\'batch_size\']'.format(self._model_fn))
batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
if batch_size_for_model_fn is not None:
@@ -1347,8 +1380,9 @@ class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
examples_per_sec = self._batch_size * elapsed_steps / elapsed_time
if self._summary_writer is not None:
- example_summary = Summary(value=[Summary.Value(
- tag='examples_sec', simple_value=examples_per_sec)])
+ example_summary = Summary(value=[
+ Summary.Value(tag='examples_sec', simple_value=examples_per_sec)
+ ])
self._summary_writer.add_summary(example_summary, global_step)
logging.info('examples/sec: %g', examples_per_sec)
@@ -1487,9 +1521,8 @@ class TPUEstimator(estimator_lib.Estimator):
'`config` must be provided with type `tpu_config.RunConfig`')
if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
- raise ValueError(
- '{} are reserved keys but existed in params {}.'.format(
- _RESERVED_PARAMS_KEYS, params))
+ raise ValueError('{} are reserved keys but existed in params {}.'.format(
+ _RESERVED_PARAMS_KEYS, params))
if use_tpu:
if train_batch_size is None:
@@ -1570,8 +1603,9 @@ class TPUEstimator(estimator_lib.Estimator):
if max_steps is not None:
util_lib.check_positive_integer(max_steps, 'Train max_steps')
- return [_TPUStopAtStepHook(self._iterations_per_training_loop, steps,
- max_steps)]
+ return [
+ _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)
+ ]
def _convert_eval_steps_to_hooks(self, steps):
with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx:
@@ -1639,6 +1673,7 @@ class TPUEstimator(estimator_lib.Estimator):
# `features` in `model_fn` signature.
def _input_fn():
return input_fn(**kwargs)
+
return _input_fn
def _augment_model_fn(self, model_fn, batch_axis):
@@ -1694,9 +1729,10 @@ class TPUEstimator(estimator_lib.Estimator):
total_loss, eval_metric_ops, scaffold = _eval_on_tpu_system(
ctx, model_fn_wrapper, dequeue_fn)
iterations_per_loop_var = _create_or_get_iterations_per_loop()
- mean_loss = math_ops.div(
- total_loss,
- math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))
+ mean_loss = math_ops.div(total_loss,
+ math_ops.cast(
+ iterations_per_loop_var,
+ dtype=total_loss.dtype))
# Creates a dummy metric update_op for all metrics. Estimator expects
# all metrics in eval_metric_ops have update_op and calls them one by
@@ -1724,6 +1760,7 @@ class TPUEstimator(estimator_lib.Estimator):
evaluation_hooks=hooks,
eval_metric_ops=eval_metric_ops,
scaffold=scaffold)
+
return _model_fn
@@ -1736,15 +1773,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn))
def multi_tpu_eval_steps_on_single_shard():
- return training_loop.repeat(iterations_per_loop_var,
- single_tpu_eval_step,
- [_ZERO_LOSS],
- name='loop')
+ return training_loop.repeat(
+ iterations_per_loop_var,
+ single_tpu_eval_step, [_ZERO_LOSS],
+ name='loop')
- (loss,) = tpu.shard(multi_tpu_eval_steps_on_single_shard,
- inputs=[],
- num_shards=num_cores,
- outputs_from_all_shards=False)
+ (loss,) = tpu.shard(
+ multi_tpu_eval_steps_on_single_shard,
+ inputs=[],
+ num_shards=num_cores,
+ outputs_from_all_shards=False)
scaffold = _get_scaffold(captured_scaffold_fn)
return loss, eval_metric_ops, scaffold
@@ -1761,14 +1799,14 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
def multi_tpu_train_steps_on_single_shard():
return training_loop.repeat(
iterations_per_loop_var,
- single_tpu_train_step,
- [_INITIAL_LOSS],
+ single_tpu_train_step, [_INITIAL_LOSS],
name=b'loop')
- (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard,
- inputs=[],
- num_shards=num_cores,
- outputs_from_all_shards=False)
+ (loss,) = tpu.shard(
+ multi_tpu_train_steps_on_single_shard,
+ inputs=[],
+ num_shards=num_cores,
+ outputs_from_all_shards=False)
scaffold = _get_scaffold(captured_scaffold_fn)
return loss, scaffold
@@ -1776,6 +1814,7 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
def _wrap_computation_in_while_loop(device, op_fn):
"""Wraps the ops generated by `op_fn` in tf.while_loop."""
+
def computation(i):
with ops.control_dependencies(op_fn()):
return i + 1
@@ -1787,7 +1826,8 @@ def _wrap_computation_in_while_loop(device, op_fn):
iterations = array_ops.identity(iterations_per_loop_var)
return control_flow_ops.while_loop(
lambda i: i < iterations,
- computation, [constant_op.constant(0)], parallel_iterations=1)
+ computation, [constant_op.constant(0)],
+ parallel_iterations=1)
def _validate_tpu_training_graph():
@@ -1800,8 +1840,9 @@ def _validate_tpu_training_graph():
# Check if there is atleast one CrossReplicaSum operation in the graph
# This should be introduced by using the CrossShardOptimizer wrapper
- cross_replica_sum_ops = [o for o in operations
- if o.type == _CROSS_REPLICA_SUM_OP]
+ cross_replica_sum_ops = [
+ o for o in operations if o.type == _CROSS_REPLICA_SUM_OP
+ ]
if not cross_replica_sum_ops:
raise ValueError(
'CrossShardOptimizer must be used for model training on TPUs.')
@@ -1848,9 +1889,11 @@ def _get_scaffold(captured_scaffold_fn):
if scaffold:
wrapped_finalize = scaffold.finalize
+
def _finalize():
with _CapturingContext('Inside Scaffold.finalize'):
wrapped_finalize()
+
scaffold.finalize = _finalize
return scaffold
@@ -1865,9 +1908,8 @@ class _CapturingContext(control_flow_ops.ControlFlowContext):
def AddOp(self, op): # pylint: disable=invalid-name
for c in op.inputs:
if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access
- raise ValueError(
- '{}: Op {} depends on TPU computation {}, '
- 'which is not allowed.'.format(self._message, op, c))
+ raise ValueError('{}: Op {} depends on TPU computation {}, '
+ 'which is not allowed.'.format(self._message, op, c))
def __enter__(self):
# pylint: disable=protected-access
diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md
new file mode 100644
index 0000000000..ca1255b16b
--- /dev/null
+++ b/tensorflow/contrib/tpu/tpu_estimator.md
@@ -0,0 +1,241 @@
+# Using the Estimator API with TPUs
+
+
+This document describes how to train a TensorFlow model on TPUs using the
+Estimator API. If you are interested in the hardware itself, check out the
+[Cloud TPU documentation](https://cloud.google.com/tpu/docs).
+
+The TPU Estimator simplifies running models on a Cloud TPU by automatically
+handling numerous low-level hardware-specific details
+
+[TOC]
+
+## Introduction to Estimator
+
+[TensorFlow
+tutorials](https://www.tensorflow.org/extend/estimators) cover the Estimator
+API. At a high-level, the Estimator API provides:
+
+* `Estimator.train()` - train a model on a given input for a fixed number of
+ steps.
+* `Estimator.evaluate()` - evaluate the model on a test set.
+* `Estimator.predict()` - run inference using the trained model.
+* `Estimator.export_savedmodel()` - export your model for serving.
+
+In addition, `Estimator` includes default behavior common to training jobs,
+such as saving and restoring checkpoints, creating summaries for TensorBoard,
+etc.
+
+`Estimator` requires you to write a `model_fn` and an `input_fn`, which
+correspond to the model and input portions of your TensorFlow graph.
+
+The following code demonstrates using `TPUEstimator` with MNIST example to
+handle training:
+
+ def model_fn(features, labels, mode, params):
+ """A simple CNN."""
+ del params # unused
+
+ input_layer = tf.reshape(features, [-1, 28, 28, 1])
+ conv1 = tf.layers.conv2d(
+ inputs=input_layer, filters=32, kernel_size=[5, 5], padding="same",
+ activation=tf.nn.relu)
+ pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
+ conv2 = tf.layers.conv2d(
+ inputs=pool1, filters=64, kernel_size=[5, 5],
+ padding="same", activation=tf.nn.relu)
+ pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
+ pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
+ dense = tf.layers.dense(inputs=pool2_flat, units=128, activation=tf.nn.relu)
+ dropout = tf.layers.dropout(
+ inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
+ logits = tf.layers.dense(inputs=dropout, units=10)
+ onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)
+
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+
+ learning_rate = tf.train.exponential_decay(
+ FLAGS.learning_rate, tf.train.get_global_step(), 100000, 0.96)
+
+ optimizer = tpu_optimizer.CrossShardOptimizer(
+ tf.train.GradientDescentOptimizer(learning_rate=learning_rate))
+
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+ return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)
+
+
+ def get_input_fn(filename):
+ """Returns an `input_fn` for train and eval."""
+
+ def input_fn(params):
+ """An input_fn to parse 28x28 images from filename using tf.data."""
+ batch_size = params["batch_size"]
+
+ def parser(serialized_example):
+ """Parses a single tf.Example into image and label tensors."""
+ features = tf.parse_single_example(
+ serialized_example,
+ features={
+ "image_raw": tf.FixedLenFeature([], tf.string),
+ "label": tf.FixedLenFeature([], tf.int64),
+ })
+ image = tf.decode_raw(features["image_raw"], tf.uint8)
+ image.set_shape([28 * 28])
+ # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
+ image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
+ label = tf.cast(features["label"], tf.int32)
+ return image, label
+
+ dataset = tf.contrib.data.TFRecordDataset(
+ filename, buffer_size=FLAGS.dataset_reader_buffer_size)
+ dataset = dataset.map(parser).cache().repeat().batch(batch_size)
+ images, labels = dataset.make_one_shot_iterator().get_next()
+ # set_shape to give inputs statically known shapes.
+ images.set_shape([batch_size, 28 * 28])
+ labels.set_shape([batch_size])
+ return images, labels
+ return input_fn
+
+
+ def main(unused_argv):
+
+ tf.logging.set_verbosity(tf.logging.INFO)
+
+ run_config = tpu_config.RunConfig(
+ master=FLAGS.master,
+ model_dir=FLAGS.model_dir,
+ session_config=tf.ConfigProto(
+ allow_soft_placement=True, log_device_placement=True),
+ tpu_config=tpu_config.TPUConfig(FLAGS.iterations, FLAGS.num_shards),)
+
+ estimator = tpu_estimator.TPUEstimator(
+ model_fn=model_fn,
+ use_tpu=FLAGS.use_tpu,
+ train_batch_size=FLAGS.batch_size,
+ eval_batch_size=FLAGS.batch_size,
+ config=run_config)
+
+ estimator.train(input_fn=get_input_fn(FLAGS.train_file),
+ max_steps=FLAGS.train_steps)
+
+
+Although this code is quite simple by appearance, there are some new
+concepts to learn for using `TPU`s. The next section will cover the most
+important details.
+
+## New Concepts Related to TPU/TPUEstimator
+
+TF programs run with `TPU Estimator` use an [in-graph
+replication](https://www.tensorflow.org/deploy/distributed) approach.
+
+In-graph replication (also known as single-session replication) differs from
+the between-graph replication (also known as multi-session replication)
+training typically used in distributed TensorFlow. The major
+differences include:
+
+1. The TensorFlow Session master is not local anymore. The user python program
+ creates one single graph that is replicated across all the cores in the Cloud
+ TPU. The typical configuration today sets the TensorFlow session master to be
+ the first worker.
+
+1. The input pipeline is placed on remote hosts (instead of local) to ensure the
+ training examples can be fed as fast as possible to TPU system. All queue-based
+ input pipelines do not work effectively. Dataset (tf.data) is
+ required.
+
+1. Workers in the TPU system operate in synchronous fashion, and each perform
+ the same step at the same time.
+
+Regarding programming model, _"The programmer picks a (large) batch size B and
+writes the program (and sets hyperparameters) based on that batch size. The
+system distributes the computation across the available devices."
+
+To align these, `TPUEstimator` wraps the computation (the `model_fn`) and
+distributes it to all available TPU chips.
+
+To summarize:
+
+- The `input_fn` models the input pipeline running on remote host CPU. Use
+ `tf.data` to program the input Ops. `input_fn` is expected to be invoked
+ multiple times when using TPU pods. Each handles one device's input of the
+ global batch. The shard batch size should be retrieved from
+ `params['batch_size']`. We plan to provide better abstraction about the
+ sharding mechanism for `tf.data` to remove the `params['batch_size']`.
+
+- The `model_fn` models the computation which will be replicated and distributed
+ to all TPU chips. It should only contains ops that are supported by TPUs.
+
+## Convert from Vanilla Estimator to TPUEstimator
+
+It is always recommended to port a small, simple model first to make sure that
+you are familiar with the basic concepts of `TPUEstimator` and test end-to-end
+behavior. Once your simple model runs, gradually add more functionality.
+In addition, there are several sample models, available at
+[github.com/tensorflow/tpu-demos](https://github.com/tensorflow/tpu-demos).
+
+To convert your code from the vanilla `Estimator` class to use TPUs, change the
+following (note some of the details may change over time):
+
+- Switch from `tf.estimator.RunConfig` to `tf.contrib.tpu.RunConfig`.
+- Set the `TPUConfig` (part of the `tf.contrib.tpu.RunConfig`) to specify the
+ `iterations_per_loop`, number of iterations to run on the TPU device for one
+ `session.run` call (per training loop), and `num_shards`, the number of shards
+ (typically the number of TPU cores you’re running on). TPUs run a number of
+ iterations of the training loop before returning to host. Until all iterations
+ on the TPU device are run, no checkpoints or summaries will be saved. In the
+ future, we’ll choose a reasonable default.
+- In `model_fn`, use `tf.contrib.tpu.CrossShardOptimizer` to wrap your
+ optimizer. Example:
+
+ optimizer = tpu_optimizer.CrossShardOptimizer(
+ tf.train.GradientDescentOptimizer(learning_rate=learning_rate))
+
+- Switch from `tf.estimator.Estimator` to `tf.contrib.tpu.TPUEstimator`.
+
+The default `RunConfig` will save summaries for TensorBoard every 100 steps and
+write checkpoints every 10 minutes.
+
+
+## FAQ
+
+### Why `tf.data` is Required for the Input Pipeline
+
+There are two reasons:
+
+1. The user code runs on the client, while the TPU computation is executed on
+ the `worker`. Input pipeline ops must be placed on the remote worker for
+ good performance. Only `tf.data` (Dataset) supports this.
+
+1. In order to amortize the TPU launch cost, the model train step is wrapped in
+ a `tf.while_loop`, such that one `Session.run` actually runs many iterations
+ for one train loop. To remove network back and forth, the input pipeline
+ in the future will be wrapped in a `tf.while_loop` and be placed on the
+ corresponding `worker`. Withou this, unnecessary network latency becomes
+ the performance bottleneck for models with short training-step times, or in
+ environments where network latency is higher. Only `tf.data` can be wrapped
+ by a `tf.while_loop`.
+
+
+### How to add other CPU Ops into Graph
+As `model_fn` only allows TPU Ops for computation, the easier workaround to add
+CPU Ops into Graph is:
+
+1. Create a [SessionRunHook](https://www.tensorflow.org/api_docs/python/tf/train/SessionRunHook).
+1. Modify the graph in the `def begin(self)`,
+1. Pass the hook to `TPUEstimator.train`.
+
+### Running On GCP Cloud TPUs
+To run your models on GCP Cloud TPUs refer to the [Cloud Documentation](https://cloud.google.com/tpu/docs/tutorials/mnist).
+Refer to this link for all [Cloud TPU documentation](https://cloud.google.com/tpu/docs).
+
+
+### Profiling
+You can profile the `worker` by using instructions as spcified in the [Cloud TPU Tools](https://cloud.google.com/tpu/docs/cloud-tpu-tools).
+
+
+### Is `int64` supported?
+`int64` is not supported by TPU. Cast to int32 if applicable. The only exception
+is global step, which relies on `assign_add`. `int64` support for global step
+is added to ensure checkpoint compatibility between `TPUEstimator` and non-TPU
+`Estimator`.
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index 2a0ef0e6b3..dbdbb08a82 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -53,7 +53,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
sp_tensor1 = sparse_tensor.SparseTensor(
array_ops.constant(ind1, dtypes.int64),
array_ops.constant(val1, dtypes.int64),
- array_ops.constant(shape1, dtypes.int64))
+ array_ops.placeholder_with_default(shape1, shape=[2]))
ind2 = np.array([
[0, 0, 1],
[0, 1, 0],
@@ -68,7 +68,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
sp_tensor2 = sparse_tensor.SparseTensor(
array_ops.constant(ind2, dtypes.int64),
array_ops.constant(val2, dtypes.int64),
- array_ops.constant(shape2, dtypes.int64))
+ array_ops.placeholder_with_default(shape2, shape=[3]))
sp_tensor3 = sparse_tensor.SparseTensor(
array_ops.constant([[1, 9], [2, 2], [2, 10]], dtypes.int64),
array_ops.constant([7, 15, 2], dtypes.int64),
@@ -320,6 +320,18 @@ class BatchSequencesWithStatesTest(test.TestCase):
def testNotAMultiple(self):
num_unroll = 3 # Not a divisor of value_length -
# so padding would have been necessary.
+
+ # Use placeholder_with_default in sequences to make sure we get runtime
+ # error instead of shape inference error
+ sequences = {
+ "seq1": array_ops.placeholder_with_default(self.sequences["seq1"],
+ shape=(None, 5)),
+ "seq2": array_ops.placeholder_with_default(self.sequences["seq2"],
+ shape=(None, 4, 2)),
+ "seq3": self.sequences["seq3"],
+ "seq4": self.sequences["seq4"],
+ }
+
with self.test_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
".*should be a multiple of: 3, but saw "
@@ -330,7 +342,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
with coord.stop_on_exception():
next_batch = sqss.batch_sequences_with_states(
input_key=self.key,
- input_sequences=self.sequences,
+ input_sequences=sequences,
input_context=self.context,
input_length=3,
initial_states=self.initial_states,
@@ -493,6 +505,18 @@ class BatchSequencesWithStatesTest(test.TestCase):
expected_seq4_batch2=expected_seq4_batch2)
+class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest):
+
+ def setUp(self):
+ self._prev_value = ops._USE_C_API
+ ops._USE_C_API = True
+ super(BatchSequencesWithStatesTestWithCApi, self).setUp()
+
+ def tearDown(self):
+ super(BatchSequencesWithStatesTestWithCApi, self).tearDown()
+ ops._USE_C_API = self._prev_value
+
+
class PaddingTest(test.TestCase):
def testPaddingInvalidLengths(self):
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index 80de0f6eb7..fdfd27d6a4 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -40,7 +40,7 @@ PARAM_RE = re.compile(r"""
((?P<val>[^,\[]*) # single value: "a" or None
|
\[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
- ($|,)""", re.VERBOSE)
+ ($|,\s*)""", re.VERBOSE)
def _parse_fail(name, var_type, value, values):
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index 28e4b4d01e..16397622ed 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -55,7 +55,7 @@ class HParamsTest(test.TestCase):
self.assertEqual(12, hparams.aaa)
self.assertEqual(2.0, hparams.b)
self.assertEqual('relu6', hparams.c_c)
- hparams.parse('c_c=relu4,b=-2.0e10')
+ hparams.parse('c_c=relu4, b=-2.0e10')
self.assertDictEqual({
'aaa': 12,
'b': -2.0e10,
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h
index 6518e7a10f..61fc6f36f7 100644
--- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h
+++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
+#ifndef TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
+#define TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
#include <string>
@@ -31,4 +31,4 @@ Status ConvertConstantsToImmutable(const string& in_graph_filename,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
+#endif // TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD
index 38a84ffb10..80a5d07ea4 100644
--- a/tensorflow/contrib/verbs/BUILD
+++ b/tensorflow/contrib/verbs/BUILD
@@ -99,7 +99,7 @@ cc_library(
alwayslink = 1,
)
-tf_cuda_library(
+cc_library(
name = "rdma_rendezvous_mgr",
srcs = ["rdma_rendezvous_mgr.cc"],
hdrs = ["rdma_rendezvous_mgr.h"],
@@ -114,7 +114,7 @@ tf_cuda_library(
],
)
-cc_library(
+tf_cuda_library(
name = "rdma_mgr",
srcs = ["rdma_mgr.cc"],
hdrs = ["rdma_mgr.h"],
@@ -141,6 +141,8 @@ tf_cuda_library(
"//conditions:default": [],
}),
deps = [
+ ":grpc_verbs_client",
+ ":verbs_service_proto_cc",
":verbs_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md
index 7c1c8ea459..1b99f4ce4f 100644
--- a/tensorflow/contrib/verbs/README.md
+++ b/tensorflow/contrib/verbs/README.md
@@ -24,66 +24,144 @@ The design is based on TensorFlow r1.0. An RDMA path is added between servers fo
During the server setup, an RDMA manager is created to manage low-level RDMA components such as RDMA channel and RDMA adapter, an RDMA rendezvous manager is created to oversee send/recv operations between servers. Following the distributed TensorFlow design philosophy, the send operation is passive, i.e. merely placing a tensor in the local out-going table. It is the receive operation that actually initiates the tensor transfer.
-TensorFlow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for RDMA operations where pinned memory is required. Two remedies are possible, either the memory is pinned, transfer, then unpinned for each and every tensor to be transferred, or a buffer is pre-allocated and pinned for each tensor. The former incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. The latter incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former. The second approach is adopted in this design. Each RDMA channel, representing a RDMA connection to a peer, contains a table of pinned buffers for all the seen tensors that requires transfer. It is assumed that the tensor size rarely changes across different steps. So only one buffer is created for the same tensor across all the steps. In the rare case when the tensor size does increases, the old buffer is discarded and new buffer of larger size is created and pinned.
+TensorFlow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for RDMA operations where pinned memory is required. Few remedies are possible:
+1. The memory is pinned, transfered, then unpinned for each and every tensor to be transferred. This incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow.
+2. Buffer is pre-allocated and pinned for each tensor. This incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former.
+3. Following HKUST research on the use of GPU direct, and their [GDR implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gdr/README.md), there is a smart way to benefit from the TensorFlow allocation theme which is mostly pool based, i.e allocators pre-allocate a large memory block, and allocate the tensors from there. By attaching a custom Visitor to relevant alloactors, we can do a single registration of the entire memory block, which zeros the registration overhead. Once the block is registered, each new tensor allocated will be at a registred address, which will allow us to do direct RDMA writes to it.
-When a tensor is prepared for transfer, it is first converted to TensorProto, then the proto is serialized to byte array and copied to the pinned buffer. The content of the buffer is transferred to the remote node via RDMA write. On the remote side, the process is reversed. This is illustrated in the diagram below. The conversion of TensorProto is introduced to simplify transfer of string-tensors. Also since the TensorProto lives in host memory, even if the origin tensor lives in the device, the pinned buffers are all allocated in the host memory.
-![TensorFlow RDMA path](./design_diagram.png)
+For best performance, we will adopt HKUST 0 copies approach in our solution. This means:
+
+1. Tensor writes will be done directly from the source tensor to the **result** tensor, with no memory copies in between. This should be done for all DMAable tensors which are located either on CPU or on a RDMA compatible GPU device (GPU direct).
+2. Non DMAable tensors (CanMemCopy == false) will be serialized to a TensorProto on the sender side, RDMA written to a registered buffer on the receiver side, and then deserialized by the receiver.
+3. Tensors which are located on a non-RDMA-compatible GPU, will be RDMA written to a registered CPU **proxy** buffer on the receiver side, and then copied to GPU by the receiver.
-The following improvements can be made in the future. First, conversion to TensorProto and serialization can be avoided for numeric (float/int) tensors since their internal buffer can be access directly as byte array. Second, the pinned buffer may be allocated on device if the tensor is located in the device. This avoids extra device-to-host copy at the expense of extra device memory consumption.
## Design details
-### RDMA components
+### Terminology
-* **RDMA adapter:** The base for RDMA communications. It may contain multiple channels and buffers. It is responsible for handling various incoming RDMA messages.
-* **RDMA channel:** Responsible for RDMA connection to a particular node. It manages multiple buffers. A channel has a callback table which stores all the callbacks for the requested tensors.
-* **RDMA buffer:** Responsible for sending or receiving data. It has a fixed size memory to store the data. It has a queue to store the pending jobs. There are three types of buffers, message buffer, ACK buffer and tensor buffer. A channel has two message buffers, two ack buffers and many tensor buffers.
-* **RDMA manager:** Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc.
-* **RDMA rendezvous manager:** manages multiple rdma rendezvous.
-* **RDMA rendezvous:** a derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up.
+* **Sender** - The node which sends the tensor.
+* **Receiver** - The node which receives the tensor.
+* **Result tensor** - The destination tensor, allocated on its appropriate device.
+* **Proxy tensor** - A CPU allocated tensor, which will be used in the case where the result tensor cannot be RDMA written to directly (GPU direct is disabled or not available). The RDMA write will therefore be done to the proxy tensor, and afterwards we will do a manual local copy from it to the result tensor.
-### The SEND operation
+### Messages
-In TensorFlow, when rendezvous sends a tensor, it merely puts a tensor in a local table in the corresponding rendezvous. If the tensor has been requested, a callback exists in the table. "send" will activate the callback, which tries to send the tensor across the node.
+* RDMA_MESSAGE_TENSOR_REQUEST
+* RDMA_MESSAGE_META_DATA_RESPONSE
+* RDMA_MESSAGE_TENSOR_RE_REQUEST
+### Transport protocol
-### The RECV operation
+The tensor transfer process is initiated when the receiver requests a tensor. In code it is done by calling **Rendezvous::Recv()** or **Rendezvous::RecvAsync()**. The TensorFlow base implementation handles the case where the requested tensor is located on the same node. The more interesting case where the requested tensor is located on a remote node (receiver != sender) is to be handled in a derivation of the pure virtual **BaseRemoteRendezvous::RecvFromRemoteAsync()**. TensorFlow provides a default GRPC based implementation which comes in the vanilla version but suffers in scalability when running large models. Our RDMA based implementation presumes to be more scalable. HKUST's contrib GDR implementation is more scalable than GRPC, and less scalable than ours, only because we did our evolution based on it.
-When a tensor is requested, rendezvous' recv function is called. The function first places a callback in the channel's callback table, which will be activated once the tensor is sent from the source. In the next step, a message is sent to notify the source of the requested tensor. Once the source receives the message, it will check locally for the tensor, if not found, a callback is placed in the table, otherwise, the tensor id will be placed at corresponding RDMA buffer's job queue for future transmission. When a tensor is scheduled to be transmitted, the RDMA buffer needs to have the memory allocated and initialized (registered with the remote buffer info). If the memory is not ready, the transmission is deferred, a message is sent to the destination to establish the memory first. The other case a transmission can be deferred is when the buffer is still being used by an on-going transmission.
+Our entry point is the implementation of **RdmaRemoteRendezvous::RecvFromRemoteAsync()**, located in rdma_rendezvous_mgr.cc. The implementation creates a new **RdmaTensorRequest** object, keyed by request index (uint32_t), stores it in a list of pending requests, and calls its **Start()** method. The **Start()** method basically does 2 things:
-### Three types of RDMA buffers
+1. Allocate the result tensor (and the proxy tensor if required).
+2. Send a **RDMA_MESSAGE_TENSOR_REQUEST** to the sender, containing the address of the destination tensor (result/proxy) for RDMA write.
-* **Message buffer:** responsible for sending message only.
-* **Ack buffer:** once a message is sent, the recipient needs to send an ack via the ack buffer to free up the message buffer. An ack buffer is exclusively for its coupled message buffer.
-* **Tensor buffer:** responsible for sending tensors. The recipient needs to send back a message to free up the sending buffer.
+In order to allocate the result and proxy tensors, we need to know the tensor's meta-data, i.e. shape and data-type for DMAable tensors, and proto-size for serialized tensors. Unfortunately, this information is only available on the sender side which complicates manners. In order to avoid sending extra messages for querying the meta-data at each step, we store a local meta-data cache per tensor, which will only be update upon changes. Based on the assumption that the meta-data of a tensor rarely changes between steps, we expect that on most times the cache will only be updated once. The sender is responsible to detect changes in the meta-data, and update the receiver. In order for the sender to know that the meta-data had changed, each **RDMA_MESSAGE_TENSOR_REQUEST** will contain the meta-data that the receiver had grabbed from the local cache. The sender will then compare the meta-data from the message to the tensor's new meta-data.
-### RDMA packet format
+When the sender receives an **RDMA_MESSAGE_TENSOR_REQUEST**, it will create a new **RdmaTensorResponse** object for the given request message, store it in a list of pending responses, and will invoke its **Start()** method. The **Start()** method does the following:
-|type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|data_type|tensor_shape|tensor_bytes|tensor_buffer|
+1. Grab the source tensor from the local table (In code, **RecvLocalAsync()**).
+2. If the source tensor is not DMAable, serialize it to a TensorProto.
+3. If the source tensor is located on a device which cannot be DMA written from, copy it to CPU.
+4. If it is the first time this tensor is requested, or if the tensor's meta-data changed:
+ 1. Clone the tensor's data to be sent later.
+ 2. Send a **RDMA_MESSAGE_META_DATA_RESPONSE** containing the new meta-data.
+5. Otherwise:
+ 1. RDMA write the tensor (or TensorProto) to the destination address and rkey specified in the request message. The immediate value for the write will be the request index.
-### Six types of RDMA messages
-* RDMA_MESSAGE_ACK
-* RDMA_MESSAGE_BUFFER_IDLE
-* RDMA_MESSAGE_BUFFER_REQUEST
-* RDMA_MESSAGE_BUFFER_RESPONSE
-* RDMA_MESSAGE_TENSOR_REQUEST
-* RDMA_MESSAGE_TENSOR_WRITE
-
-### Actions upon receiving RDMA messages
-* RDMA_MESSAGE_ACK
- * sender: mark local ack buffer idle.
- * receiver: mark remote message buffer idle, send next item.
-* RDMA_MESSAGE_BUFFER_IDLE
- * sender: mark local message buffer idle, send next item.
- * receiver: send ack, set remote tensor buffer idle, send next item.
-* RDMA_MESSAGE_BUFFER_REQUEST
- * sender: mark local message buffer idle, send next item.
- * receiver: send ack, find or create tensor buffer, send BUFFER_RESPONSE.
-* RDMA_MESSAGE_BUFFER_RESPONSE
- * sender: mark local message buffer idle, send next item.
- * receiver: send ack, set remote buffer info, set local and remote buffer idle, send next item.
-* RDMA_MESSAGE_TENSOR_REQUEST
- * sender: mark local message buffer idle, send next item.
- * receiver: send ack, find or create tensor buffer, enqueue tensor id, send next item.
-* RDMA_MESSAGE_TENSOR_WRITE
- * sender: mark local message buffer idle, send next item.
- * receiver: run callback.
+
+When the receiver receives the **RDMA_MESSAGE_META_DATA_RESPONSE**, it will locate the relevant **RdmaTensorRequest** using the request index specified in the message, and invoke its **RecvTensorMetaData()** which does the following:
+
+1. Update the local meta-data cache.
+2. Reallocate the result/proxy tensors.
+3. Re-send the tensor request. For tracability, the new message has a different name: **RDMA_MESSAGE_TENSOR_RE_REQUEST**.
+
+When the sender receives a **RDMA_MESSAGE_TENSOR_RE_REQUEST**, it will locate the relevant **RdmaTensorResponse** using the request index specified in the message, and invoke its **Resume()** method, which will RDMA write the contents of the tensor that was cloned earlier, to the new remote address specified in the re-request.
+
+When the receiver receives the RDMA write, it will locate the relevant **RdmaTensorRequest** using the request index which is the immediate value. It will then invoke its **RecvTensorContent()** which does the following:
+
+1. Proxy copy/deserialize if required.
+2. Invoke the done callback.
+3. Deallocate the result/proxy tensors and remove the request from the pending list.
+
+![alt text](verbs_with_0_copies.png "Transport protocol")
+
+### Additional design notes
+
+1. When the sender receives a tensor request, the source tensor may or may not be ready yet. The situation is handled through a process of tag matching:
+ * If the request arrives before the tensor is ready, then a callback is put in a local table, and will be invoked once the tensor arrives.
+ * If the tensor is ready before the request arives, than the tensor is put in a local table. When the request arrives, it will invoke the callback immediatly.
+ In code it is done by calling **RecvLocalAsync()**, which receives the tensor's key, step-id, and the callback.
+2. When the callback is invoked, the relevant tensor is removed from the tag matching table. In the case where we need to send the tensor's meta-data, the **RdmaTensorResponse** will store a copy of the tensor until the re-request arrives.
+3. The sending of protocol messages (**RDMA_MESSAGE_TENSOR_REQUEST**, **RDMA_MESSAGE_META_DATA_RESPONSE** and **RDMA_MESSAGE_TENSOR_RE_REQUEST**) is done by the class **RdmaMessageBuffer**. All messages are sent using RDMA writes from/to fixed messages buffers. This implies that we cannot send on a specific channel more than one message at a time. In order to synchronize the messages, the **RdmaMessageBuffer** holds the a local and remote buffer statuses which can be either busy or idle. When a write is issued, both statuses will be changed to busy. When the write-complete event is received, the local status is changed to idle. When the write is received on the remote side, the remote side will parse the message, and return an ACK back to the sending side on which the sending side will update the remote status to idle. When both the local and remote statuses are idle, the next message can be sent.
+5. ACK writes are empty writes (hence they require no buffer) with immediate value 0xFFFFFFFE. Message writes have the immediate value 0xFFFFFFFF. All other writes are tensor-content writes whose immediate value is the request-index.
+
+### RDMA components
+
+* **enum RdmaImmDataType** - Immediate types to distinguish between different RDMA writes on the remote side. Ack writes and control-message writes have a fixed immediate value. The rest of the writes are tensor writes and the immediate value is the relevant request index.
+* **enum RdmaWriteIDType** - Types to distinguish between different RDMA write-complete events: Ack, control message and tensor writes.
+* **class RdmaWriteID** - Context for RDMA write complete events. Holds the RdmaWriteIDType and additional data.
+* **class RdmaTensorMetaData** - Meta-data for a tensor (type, shape, is_dead, proto_size).
+* **class RdmaMemoryMgr** - Manages the meta-data cache, and the registered memory regions.
+* **class RdmaTensorRequest** - Holds and manages information for a single tensor request throughout the entire receive cycle. API:
+ * **Start()** - Start the request sequence.
+ * Allocate the result tensor (and proxy tensor if required).
+ * Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side.
+ * **RecvTensorMetaData()** - Receive meta-data from the remote side.
+ * Update the local meta-data cache.
+ * Reallocate the result tensor (and proxy tensor if required).
+ * Re-send the request to the remote side.
+ * **RecvTensorContent()** - Receive tensor content from the remote side (RDMA write was completed).
+ * Decode proto if required and/or move to GPU if the content was not written to it directly (GPU direct is not avaliable).
+ * Invoke the done callback.
+* **class RdmaTensorResponse** - Holds and manages information for a single tensor response throughout the entire send cycle. API:
+ * **Start()** - Start the response sequence.
+ * Find the tensor in the local tag-match table.
+ * Compare the tensor's meta-data to the meta-data in the message (taken from the requester's local cache).
+ * If meta-data changed:
+ * Clone the tensor to be sent later.
+ * Send a meta-data update message and wait for re-request.
+ * Else:
+ * Send the tensor's content (using direct RDMA write).
+ * **Resume()** - Resume the response sequence after a re-request. Send the tensor's content that was cloned earlier.
+ * **Destroy()** - Destroy the response's resources and remove it form the pending list.
+* **class RdmaAdapter** - The base for RDMA communications. It may contain multiple channels and buffers. It is responsible for handling various incoming RDMA messages.
+* **class RdmaChannel** - Responsible for RDMA connection to a particular node. It manages messagee buffers. A channel has a request table which stores all the pending tensor requests.
+* **class RdmaMessageBuffer** - Responsible for sending or receiving messages. It has a fixed size memory to store the data. It has a queue to store the pending jobs. A channel has two message buffers one for tx and one for rx.
+* **class RdmaMgr** - Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc.
+* **class RdmaRendezvousMgr** - Manages multiple rdma rendezvous.
+* **class RdmaRemoteRendezvous** - A derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up.
+
+### Message structure:
+
+| type | name_size | name | step_id | request_index | remote_addr/checksum | rkey | is_dead | data_type | tensor_shape | tensor_bytes | error_status |
+|------|---------- |------|---------|---------------|----------------------|------|---------|-----------|--------------|--------------|-----------------------|
+| 1B | 2B | 512 | 8B | 8B | 8B | 4B | 1B | XB | XB | 8B | Size - 4B, proto - XB |
+
+* **RDMA_MESSAGE_TENSOR_REQUEST** - (receiver ==> sender) The original tensor request.
+ * type - The message type.
+ * name (name_size) - Name of the requested tensor.
+ * step_id - Step ID.
+ * request_index - Request index.
+ * remote_addr/rkey - Address/rkey of the result/proxy tensor. Irrelevant for first-time request.
+ * is_dead/data_type/tensor_shape/tensor_bytes - The current meta-data as stored in the receiver local cache. The sender will use that information to know if the receiver's cache requires updating.
+* **RDMA_MESSAGE_META_DATA_RESPONSE** - (sender ==> receiver) The meta-data update message in case meta-data had changed (or if it is the first time the tensor is requested).
+ * type - The message type.
+ * request_index - Request index.
+ * is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data.
+ * checksum - In data validation mode, this will hold the checksum of the source tensor.
+* **RDMA_MESSAGE_TENSOR_RE_REQUEST** - (receiver ==> sender) Tensor re-requset after meta-data update and reallocation of result/proxy tensors.
+ * type - The message type.
+ * name (name_size) - Name of the requested tensor.
+ * step_id - Step ID.
+ * request_index - Request index.
+ * remote_addr/rkey - Address/rkey of the reallocated result/proxy tensor.
+* **RDMA_MESSAGE_ERROR_STATUS** - (sender ==> receiver) Notify the receiver that an error had occured on the sender side, so it can propagate it to the upper levels.
+ * type - The message type.
+ * name (name_size) - Name of the requested tensor.
+ * step_id - Step ID.
+ * request_index - Request index.
+ * error_status - The error status (code, message, details).
diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.h b/tensorflow/contrib/verbs/grpc_verbs_client.h
index 358977f925..2cfaa4986c 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_client.h
+++ b/tensorflow/contrib/verbs/grpc_verbs_client.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+#define TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
@@ -47,4 +47,4 @@ class GrpcVerbsClient {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+#endif // TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc
index f2af6b79fb..742f946c95 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service.cc
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc
@@ -122,17 +122,15 @@ Status GrpcVerbsService::GetRemoteAddressSync(
rc->SetRemoteAddress(ra, false);
rc->Connect();
int i = 0;
- int idx[] = {1, 0, 3, 2};
- std::vector<RdmaBuffer*> mb(rc->message_buffers());
- CHECK_EQ(request->mr_size(), 4);
+ int idx[] = {1, 0};
+ std::vector<RdmaMessageBuffer*> mb(rc->message_buffers());
+ CHECK_EQ(request->mr_size(), RdmaChannel::kNumMessageBuffers);
for (const auto& mr : request->mr()) {
// the connections are crossed, i.e.
// local tx_message_buffer <---> remote rx_message_buffer_
// local rx_message_buffer <---> remote tx_message_buffer_
- // local tx_ack_buffer <---> remote rx_ack_buffer_
- // local rx_ack_buffer <---> remote tx_ack_buffer_
- // hence idx[] = {1, 0, 3, 2}.
- RdmaBuffer* rb = mb[idx[i]];
+ // hence idx[] = {1, 0}.
+ RdmaMessageBuffer* rb = mb[idx[i]];
RemoteMR rmr;
rmr.remote_addr = mr.remote_addr();
rmr.rkey = mr.rkey();
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.h b/tensorflow/contrib/verbs/grpc_verbs_service.h
index aa509602b5..444c863b94 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service.h
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
#ifdef TENSORFLOW_USE_VERBS
@@ -69,4 +69,4 @@ void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
} // namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+#endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
index 86431ca030..1f0f10517e 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+#define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
#include "grpc++/impl/codegen/async_stream.h"
#include "grpc++/impl/codegen/async_unary_call.h"
@@ -86,4 +86,4 @@ class VerbsService GRPC_FINAL {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+#endif // TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
diff --git a/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md b/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md
new file mode 100644
index 0000000000..956b8f2147
--- /dev/null
+++ b/tensorflow/contrib/verbs/patch_notes_verbs_with_0_copies.md
@@ -0,0 +1,87 @@
+## Verbs implementation to use direct tensor writes (0 copies)
+
+### Motivation:
+
+Following HKUST research on the use of GPU direct, and their [GDR implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gdr/README.md), we wish to adopt the 0 copies approach and apply it to the current verbs implementation, while keeping the current implementation advantages, such as configurability and the use of RDMA for control messages.
+
+### Performance:
+
+Compared with the current GRPC, verbs and GDR implementation, the result implementation gave the best performance for every model, with any number of nodes. For VGG16 on 8 nodes with 4 P100 GPUs each, the prototype beat the second place by over 15%.
+
+### Implementation requirements:
+
+1. Tensor writes need to be done directly from the source Tensor to the destination Tensor, with no memory copies in between. This should be done for all DMAble tensors which are located either on CPU or on a RDMA compatible GPU device (GPU direct).
+2. Non DMAble tensors (CanMemCopy == false) will be serialized to proto on the sender side, RDMA written to a registered buffer on the receiver side, and then deserialized by the receiver.
+3. Tensors which are located on a non-RDMA-compatible GPU, will be RDMA written to a registered CPU proxy buffer on the receiver side, and then copied to GPU by the receiver.
+
+### Implementation constrains:
+
+For best stability and proof of correctness, we will divide the implementation to two stages:
+1. At first stage we will keep changes to the current implementation to the minimum possible. The expense will be that we may have unused or unnecessary code leftovers, which may also affect performance.
+2. At second stage, we will re-iterate over the code and remove irrelevant code parts.
+The design of the solution aims that we will achieve both stages with relative ease.
+
+### Design guidelines:
+
+1. Since we do not want to do any unnecessary memory copying, we will no longer allocate a fixed CPU buffer as the destination for the RDMA write. Instead we will do the writing directly to the result tensor, or if the result tensor is on a device which does not support RDMA, we will do the writing to a proxy CPU tensor and then copy its content to the result tensor.
+2. The address of the destination Tensor needs to be sent to the sender side for writing, meaning that the result/proxy tensor should be pre-allocated on the receiver side, prior to sending the tensor request. In order to do that, we need to know its meta-data, i.e. shape and data-type for DMAble tensors, and proto-size for serialized tensors. Unfortunately, this information is only available on the sender side which complicates manners. In order to avoid sending extra messages for querying the meta-data on each step, we store a local meta-data cache per tensor. Based on the assumption that the meta-data of a tensor rarely changes between steps, we expect that on most times the cache will only be updated once. When the sender receives a request for a tensor, if it is the first time this tensor is requested, or in the rare case that the meta-data did change, the sender will first send a meta-data response, on which the receiver will update the local cache, and reallocate the result/proxy tensors if required. When the receiver sends the tensor request, it will contain also the meta-data currently stored in its local cache, so the sender can compare it to see if there was a change.
+3. When the sender writes the tensor content to the result tensor, no additional data is being written with it. That means we need to reside on ibverbs immediate (uint32_t) to indicate which request we are responding to (in order to trigger the receive callback). The easiest and most elegant way is to key the recv callback with a unique request_index (uint32_t), instead of the current key_with_step_id (string).
+4. Since the sender no longer writes the tensor from/to fixed buffers, we no longer need to schedule the writes using the local/remote status. In addition we no longer rely on the RmdaTensorBuffer members as the source/destination addresses and rkey/lkey. Instead, each RdmaTensorBuffer will hold multiple "Response" objects (one per step-id), from which we derive destination address and rkey. The source address and lkey are always the ones of the source Tensor.
+5. With the addition of tensor pre-allocation, we noticed there is a large code similarity between sending the first tensor request and re-sending the request in case of meta-data changes. After implementing a common method for tensor pre-allocation, it turned out that implementation becomes much simpler by encapsulating the process of request sending/re-sending, meta-data response callback and content response callback, all in a single "Request" class. The request class holds all the relevant request information, which reduces excessive parameter passing and lambda capturing. This decision is purely for elegance and code simplicity, and we decided to implement it in first stage because it makes the implementation much easier.
+
+### New types/classes:
+
+* **enum RdmaImmDataType** - Immediate types to distinguish between different RDMA writes on the remote side. Ack writes and control-message writes have a fixed immediate value. The rest of the writes are tensor writes and the immediate value is the relevant request index.
+* **enum RdmaWriteIDType** - Types to distinguish between different RDMA write-complete events: Ack, control message, tensor DMA write and tensor proto write.
+* **class RdmaWriteID** - Context for RDMA write complete events. Holds the RdmaWriteIDType and additional data.
+* **class RemoteAddressContext** - Remote address information (address + mr). Will be passed as write context for tensor proto writes.
+* **class RdmaTensorMetaData** - Meta-data for a tensor (type, shape, is_dead, proto_size).
+* **class RdmaMemoryMgr** - Manages the meta-data cache, and the registered memory regions.
+* **class RdmaTensorRequest** - Holds and manages information for a single tensor request throughout the entire receive cycle. API:
+ * Start() - Start the request.
+ * RecvTensorMetaData() - Receive meta-data from the remote side.
+ * RecvTensorContent() - Receive tensor content from the remote side and invoke the done() callback.
+* **class RdmaTensorResponse** - Holds information for a single tensor response, such as destination address and rkey.
+
+### Protocol changes:
+
+The protocol messages themselves will remain mostly unchanged at the first stage, but will be used differently, as described below. The current messages structures already have most of the required fields for the new implementation. The only change is the "buffer_size" field which is no longer used since we are no longer sending additional information with the tensor, and thus it is now always equal to the "tensor_bytes" field. Instead, we use that field to pass the "request_index".
+
+### Message structure:
+
+| type | name_size | name | step_id | request_index | remote_addr | rkey | is_dead | data_type | tensor_shape | tensor_bytes |
+|------|---------- |------|---------|---------------|-------------|------|---------|-----------|--------------|--------------|
+| 1B | 2B | 512 | 8B | 8B | 8B | 4B | 1B | XB | XB | 8B |
+
+* **RDMA_MESSAGE_TENSOR_REQUEST** - (receiver ==> sender) The original tensor request.
+ * type - The message type.
+ * name (name_size) - Name of the requested tensor.
+ * step_id - Step ID.
+ * request_index - Request index.
+ * remote_addr/rkey - Address/rkey of the result/proxy tensor. Irrelevant for first-time request.
+ * is_dead/data_type/tensor_shape/tensor_bytes - The current meta-data as stored in the receiver local cache. The sender will use that information to know if the receiver's cache requires updating.
+* **RDMA_MESSAGE_BUFFER_REQUEST** - (sender ==> receiver) The meta-data update message in case meta-data had changed (or if it is the first time the tensor is requested).
+ * type - The message type.
+ * request_index - Request index.
+ * is_dead/data_type/tensor_shape/tensor_bytes - The up-to-date meta-data.
+* **RDMA_MESSAGE_BUFFER_RESPONSE** - (receiver ==> sender) Tensor re-requset after meta-data update and reallocation of result/proxy tensors.
+ * type - The message type.
+ * name (name_size) - Name of the requested tensor.
+ * step_id - Step ID.
+ * request_index - Request index.
+ * remote_addr/rkey - Address/rkey of the reallocated result/proxy tensor.
+ * is_dead/data_type/tensor_shape/tensor_bytes - The new meta-data. Will be removed in the next phase.
+* **RDMA_MESSAGE_TENSOR_WRITE** - (sender ==> receiver) No longer sent. There is only a direct write of the tensor content to the result/proxy tensor. Request index passed as the immediate value of the write.
+* **RDMA_MESSAGE_TENSOR_IDLE** - (receiver ==> sender) No longer sent.
+
+![alt text](verbs_with_0_copies_phase1_protocol.jpg "Phase 1 message protocol")
+
+### Second stage optimizations:
+1. Remove unused code leftovers.
+2. Remove the ACK buffer completely, since we can rely completely on its immediate value.
+
+### Future optimizations:
+1. Map the tensor names to indexes, to significantly reduce the request message size.
+2. Understand the purpose of empty tensors and if we can skip remote fetching for them.
+3. Consider concatenating multiple requests and/or using multiple message buffers.
+4. Consider a no-request architecture.
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index ae9a384565..86350a08e5 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -15,58 +15,48 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
-#include "tensorflow/contrib/verbs/rdma.h"
#include <fcntl.h>
#include <cstdlib>
-#include <fcntl.h>
-#include "tensorflow/contrib/verbs/verbs_util.h"
+
+#include "tensorflow/contrib/verbs/rdma.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
#define RoCE_V2 "RoCE v2"
namespace {
-// hash name to 32-bit integer
-uint32_t NameHash(const string& name) {
- return Hash32(name.data(), name.size(), 0x1234ABCD);
-}
// convenience function for printing message
string MessageTypeToString(RdmaMessageType rmt) {
switch (rmt) {
- case RDMA_MESSAGE_ACK:
- return "RDMA_MESSAGE_ACK";
- break;
- case RDMA_MESSAGE_BUFFER_IDLE:
- return "RDMA_MESSAGE_BUFFER_IDLE";
+ case RDMA_MESSAGE_META_DATA_UPDATE:
+ return "RDMA_MESSAGE_META_DATA_UPDATE";
break;
- case RDMA_MESSAGE_BUFFER_REQUEST:
- return "RDMA_MESSAGE_BUFFER_REQUEST";
- break;
- case RDMA_MESSAGE_BUFFER_RESPONSE:
- return "RDMA_MESSAGE_BUFFER_RESPONSE";
+ case RDMA_MESSAGE_TENSOR_RE_REQUEST:
+ return "RDMA_MESSAGE_TENSOR_RE_REQUEST";
break;
case RDMA_MESSAGE_TENSOR_REQUEST:
return "RDMA_MESSAGE_TENSOR_REQUEST";
break;
- case RDMA_MESSAGE_TENSOR_WRITE:
- return "RDMA_MESSAGE_TENSOR_WRITE";
- break;
default:
return "UNKNOWN MESSAGE";
}
@@ -347,7 +337,7 @@ uint32_t set_param(uint32_t default_val, const char* env_param) {
enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
ibv_port_attr port_attr;
- enum ibv_mtu mtu;
+ enum ibv_mtu mtu = IBV_MTU_512;
string mtu_s;
int rc, mtu_i;
@@ -459,106 +449,79 @@ void RdmaAdapter::Process_CQ() {
CHECK_GE(ne, 0);
for (int i = 0; i < ne; ++i) {
CHECK(wc_[i].status == IBV_WC_SUCCESS)
- << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " "
- << wc_[i].status << " " << static_cast<int>(wc_[i].wr_id) << " "
- << wc_[i].vendor_err;
+ << "Failed status \n"
+ << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
+ << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
// put back a recv wr.
rc->Recv();
// imm_data is the index of RX buffer in the buffer table.
uint32_t imm_data = wc_[i].imm_data;
- RdmaBuffer* rb = rc->FindBuffer(imm_data);
+ RdmaMessageBuffer* rb;
RdmaMessage rm;
- RdmaMessage::ParseMessage(rm, rb->buffer_);
- VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_);
- if (rm.type_ == RDMA_MESSAGE_ACK) {
+ if (imm_data == RDMA_IMM_DATA_ACK) {
// receive an ack to a message
rb = rc->tx_message_buffer_;
rb->SetBufferStatus(remote, idle);
rb->SendNextItem();
- } else if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
- // received a request-for-tensor message
- // send ack to release remote tx message buffer
- RdmaBuffer* ab = rc->tx_ack_buffer_;
- ab->SendNextItem();
- // find or create buffer
- RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_);
- string key_with_step_id =
- VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
- tb->EnqueueItem(key_with_step_id);
- // send the next tensor
- worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
- } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
- // receive tensor-buffer-ready message
- // send ack to release remote tx message buffer
- RdmaBuffer* ab = rc->tx_ack_buffer_;
- ab->SendNextItem();
- // find buffer
- RdmaTensorBuffer* tb =
- reinterpret_cast<RdmaTensorBuffer*>(rc->FindBuffer(rm.name_));
- tb->SetBufferStatus(remote, idle);
- worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); });
- } else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
- // remote host requests to create a tensor buffer;
- // send ack to release remote tx message buffer
- RdmaBuffer* ab = rc->tx_ack_buffer_;
- ab->SendNextItem();
- // find or create the buffer
- RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_, TENSOR);
- RemoteMR rmr;
- rmr.remote_addr = rm.remote_addr_;
- rmr.rkey = rm.rkey_;
- tb->SetRemoteMR(rmr, true);
- tb->CreateCPUBuffer(rm.buffer_size_);
- // create RDMA_MESSAGE_BUFFER_RESPONSE message
- RdmaMessage br;
- br.type_ = RDMA_MESSAGE_BUFFER_RESPONSE;
- br.name_size_ = rm.name_.size();
- br.name_ = rm.name_;
- br.buffer_size_ = rm.buffer_size_;
- br.remote_addr_ = reinterpret_cast<uint64_t>(tb->buffer_);
- br.rkey_ = tb->self_->rkey;
- string message = RdmaMessage::CreateMessage(br);
- RdmaBuffer* mb = rc->tx_message_buffer_;
- mb->EnqueueItem(message);
- mb->SendNextItem();
- } else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) {
- // remote creates a buffer and responds
- // send ack to release remote tx message buffer
- RdmaBuffer* ab = rc->tx_ack_buffer_;
- ab->SendNextItem();
- // find buffer
- RdmaTensorBuffer* tb =
- reinterpret_cast<RdmaTensorBuffer*>(rc->FindBuffer(rm.name_));
- CHECK(rm.buffer_size_ == tb->size_)
- << "rm.buffer_size = " << rm.buffer_size_
- << "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
- RemoteMR rmr;
- rmr.remote_addr = rm.remote_addr_;
- rmr.rkey = rm.rkey_;
- tb->SetRemoteMR(rmr, true);
- tb->SetBufferStatus(local, idle);
- tb->SetBufferStatus(remote, idle);
- worker_env_->compute_pool->Schedule([tb]() { tb->ReSendNextItem(); });
- } else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
- // tensor RDMA write completed
- worker_env_->compute_pool->Schedule([rm, rc]() {
- string key_with_step_id =
- VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
- rc->RunRecvCallback(key_with_step_id);
- });
+ continue;
}
- } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
- RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id);
- rb->SetBufferStatus(local, idle);
- RdmaMessage rm;
+
+ if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) {
+ // receive a tensor RDMA write
+ uint32_t request_index = imm_data;
+ RdmaTensorRequest* request = rc->GetTensorRequest(request_index);
+ request->RecvTensorContent();
+ continue;
+ }
+
+ // receive a control message
+ rb = rc->rx_message_buffer_;
RdmaMessage::ParseMessage(rm, rb->buffer_);
- VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_);
- if (rm.type_ != RDMA_MESSAGE_ACK) {
- worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); });
+ RdmaMessageBuffer::SendAck(rc);
+ RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
+ << ": Received " << MessageTypeToString(rm.type_) << " "
+ << "#" << rm.request_index_ << ": " << rm.name_;
+
+ if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
+ RdmaTensorResponse* response = rc->AddTensorResponse(rm);
+ response->Start();
+ } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) {
+ RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
+ request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_,
+ rm.is_dead_, rm.tensor_bytes_);
+#ifdef RDMA_DATA_VALIDATION
+ request->RecvTensorChecksum(rm.checksum_);
+#endif
+ } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) {
+ RdmaTensorResponse* response = rc->UpdateTensorResponse(rm);
+ response->Resume();
+ } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
+ RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
+ request->RecvErrorStatus(rm.status_);
}
+ } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
+ RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wc_[i].wr_id);
+ RDMA_LOG(2) << "Write complete of type " << wr_id->write_type;
+ switch (wr_id->write_type) {
+ case RDMA_WRITE_ID_ACK:
+ break;
+ case RDMA_WRITE_ID_MESSAGE: {
+ RdmaMessageBuffer* rb =
+ reinterpret_cast<RdmaMessageBuffer*>(wr_id->write_context);
+ rb->SetBufferStatus(local, idle);
+ rb->SendNextItem();
+ break;
+ }
+ case RDMA_WRITE_ID_TENSOR_WRITE: {
+ RdmaTensorResponse* response =
+ reinterpret_cast<RdmaTensorResponse*>(wr_id->write_context);
+ response->Destroy();
+ }
+ }
+ delete wr_id;
}
}
}
@@ -577,7 +540,7 @@ int RdmaChannel::PingPostRecv() {
int RdmaChannel::PingPostSend() {
struct ibv_send_wr wr, *bad_wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t)this;
wr.sg_list = &ping_sge_list_;
wr.num_sge = 1;
wr.opcode = IBV_WR_SEND;
@@ -588,8 +551,10 @@ int RdmaChannel::PingPostSend() {
RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
const string remote_name)
- : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
-
+ : adapter_(adapter),
+ local_name_(local_name),
+ remote_name_(remote_name),
+ request_serial_(0) {
struct ibv_sge list;
mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize,
@@ -651,29 +616,15 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
// create message and ack buffers, then initialize the tables.
{
- const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
- "tx_ack_buffer", "rx_ack_buffer"};
+ const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer"};
tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
- tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
- rx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[3]);
message_buffers_.reserve(kNumMessageBuffers);
message_buffers_.push_back(tx_message_buffer_);
message_buffers_.push_back(rx_message_buffer_);
- message_buffers_.push_back(tx_ack_buffer_);
- message_buffers_.push_back(rx_ack_buffer_);
// create buffer on host
tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
- tx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
- rx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
- // bt_mu_.lock() is not used in constructor.
- for (int i = 0; i < kNumMessageBuffers; i++) {
- uint32_t index = NameHash(buffer_names[i]);
- buffer_table_.insert({index, message_buffers_[i]});
- buffer_index_name_table_.insert({index, buffer_names[i]});
- buffer_name_index_table_.insert({buffer_names[i], index});
- }
}
CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_
<< " with error " << std::strerror(errno);
@@ -684,8 +635,6 @@ RdmaChannel::~RdmaChannel() {
CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
delete tx_message_buffer_;
delete rx_message_buffer_;
- delete tx_ack_buffer_;
- delete rx_ack_buffer_;
}
void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
@@ -711,119 +660,36 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
void RdmaChannel::Recv() {
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t)this;
struct ibv_recv_wr* bad_wr;
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
-// Lookup 32-bit buffer index from buffer name
-// Args:
-// buffer_name: name of the buffer
-// Returns:
-// 32-bit index
-uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) {
- mutex_lock lock{bt_mu_};
- BufferNameIndexTable::iterator iter =
- buffer_name_index_table_.find(buffer_name);
- CHECK(iter != buffer_name_index_table_.end());
- return iter->second;
-}
-
-// Find a buffer by its 32-bit index
-// Args:
-// index: 32-bit hash code of the tensor buffer name
-// Returns:
-// name of the tensor buffer
-RdmaBuffer* RdmaChannel::FindBuffer(const uint32_t index) {
- mutex_lock lock{bt_mu_};
- BufferTable::iterator iter = buffer_table_.find(index);
- CHECK(iter != buffer_table_.end());
- return iter->second;
-}
-
-// Find a buffer by its name
-// Args:
-// name: name of the buffer
-// Returns:
-// the named rdma buffer
-RdmaBuffer* RdmaChannel::FindBuffer(const string& name) {
- uint32_t index = LookupBufferIndex(name);
- return FindBuffer(index);
-}
-
-// Find a buffer if it exists, otherwise create one.
-// The memory inside the created buffer is not allocated.
-// Args:
-// name: the name of the buffer
-// buffer_type: TENSOR, MESSAGE or ACK.
-// Returns:
-// the named buffer
-RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
- BufferType buffer_type) {
- mutex_lock lock{bt_mu_};
- RdmaBuffer* rb;
- // find index
- BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(name);
- if (iter != buffer_name_index_table_.end()) {
- uint32_t index = iter->second;
- // find buffer
- BufferTable::iterator iter = buffer_table_.find(index);
- CHECK(iter != buffer_table_.end());
- rb = iter->second;
- } else {
- uint32_t index = NameHash(name);
- if (buffer_type == TENSOR) {
- rb = new RdmaTensorBuffer(this, name);
- } else if (buffer_type == MESSAGE) {
- rb = new RdmaMessageBuffer(this, name);
- } else if (buffer_type == ACK) {
- rb = new RdmaAckBuffer(this, name);
- }
- buffer_name_index_table_.insert({name, index});
- buffer_index_name_table_.insert({index, name});
- buffer_table_.insert({index, rb});
+RdmaTensorRequest* RdmaChannel::InsertTensorRequest(
+ const string& key, int64 step_id, Device* dst_dev,
+ const Rendezvous::Args recv_args,
+ const RdmaTensorRequest::RecvDoneCallback& done) {
+ mutex_lock lock{ct_mu_};
+ uint32_t request_index = request_serial_++;
+ if (request_serial_ > RDMA_IMM_MAX_REQUEST_ID) {
+ request_serial_ = 0;
}
- CHECK(rb);
- return rb;
+ RdmaTensorRequest request(request_index, key, step_id, this, dst_dev,
+ recv_args, done);
+ auto it = request_table_.emplace(request_index, request);
+ return &it.first->second;
}
-// Insert callback to the callback_table.
-// The callback is activated when the corresponding tensor is received.
-// Arg:
-// key: the name of the tensor
-// recv_done: the callback associated with the tensor.
-// Returns:
-// None
-void RdmaChannel::InsertRecvCallback(const string& key,
- std::function<void()> recv_done) {
+void RdmaChannel::RemoveTensorRequest(uint32_t request_index) {
mutex_lock lock{ct_mu_};
- callback_table_.insert({key, recv_done});
+ request_table_.erase(request_index);
}
-// Remove callback from the callback_table.
-// Arg:
-// key: the name of the tensor
-// Returns:
-// None
-void RdmaChannel::RemoveRecvCallback(const string& key) {
+RdmaTensorRequest* RdmaChannel::GetTensorRequest(uint32_t request_index) {
mutex_lock lock{ct_mu_};
- callback_table_.erase(key);
-}
-
-// Run named callback in the callback_table.
-// Arg:
-// key: the name of the tensor
-// Returns:
-// None
-void RdmaChannel::RunRecvCallback(const string& key) {
- std::function<void()> recv_done;
- {
- mutex_lock lock{ct_mu_};
- CallbackTable::iterator iter = callback_table_.find(key);
- CHECK(iter != callback_table_.end());
- recv_done = iter->second;
- }
- recv_done();
+ RequestTable::iterator iter = request_table_.find(request_index);
+ CHECK(iter != request_table_.end());
+ return &iter->second;
}
void RdmaChannel::Connect() {
@@ -865,11 +731,11 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
int r;
- CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV |
- IBV_QP_PATH_MTU |
- IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
- IBV_QP_MAX_DEST_RD_ATOMIC |
- IBV_QP_MIN_RNR_TIMER)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
+ IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+ IBV_QP_MAX_DEST_RD_ATOMIC |
+ IBV_QP_MIN_RNR_TIMER)))
<< "QP to Ready to Receive " << r;
memset(&attr, 0, sizeof(ibv_qp_attr));
@@ -880,33 +746,30 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;
- CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT |
- IBV_QP_RETRY_CNT |
- IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
- IBV_QP_MAX_QP_RD_ATOMIC)))
+ CHECK(!(r = ibv_modify_qp(qp_, &attr,
+ IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
+ IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+ IBV_QP_MAX_QP_RD_ATOMIC)))
<< "QP to Ready to Send " << r;
connected_ = true;
} else {
- LOG(INFO) << "channel already connected";
+ RDMA_LOG(2) << "channel already connected";
}
}
-RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name)
+RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
: channel_(channel), name_(name) {}
-RdmaBuffer::~RdmaBuffer() {
+RdmaMessageBuffer::~RdmaMessageBuffer() {
CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
FreeBuffer();
}
-void RdmaBuffer::FreeBuffer() {
+void RdmaMessageBuffer::FreeBuffer() {
if ((buffer_ != nullptr) && buffer_on_host_) {
free(buffer_);
}
- // TODO
- // release buffer if it is on device.
- // We don't support RDMABuffer on device at this moment.
}
// Allocate CPU memory for the Rdma buffer
@@ -915,7 +778,7 @@ void RdmaBuffer::FreeBuffer() {
// lock: whether or not mutex_lock the process to protect concurrency.
// Returns:
// None
-void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
+void RdmaMessageBuffer::CreateCPUBuffer(size_t size, bool lock) {
CHECK(size > 0);
if (lock) {
mu_.lock();
@@ -943,7 +806,7 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
// override: whether override existing information
// Returns:
// None
-void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
+void RdmaMessageBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
mutex_lock lock{mu_};
if ((override) || (remote_status_ == none)) {
remote_.remote_addr = rmr.remote_addr;
@@ -956,63 +819,51 @@ void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
}
// Put a task in the buffer's job queue
-void RdmaBuffer::EnqueueItem(string item) {
+void RdmaMessageBuffer::EnqueueItem(string item) {
mutex_lock lock{mu_};
queue_.push(item);
}
// Rdma-Write the content of the buffer
-void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
+void RdmaMessageBuffer::Write(uint32_t imm_data, size_t buffer_size) {
+ Write(channel_, imm_data, buffer_size, (uint64_t)buffer_, self_->lkey,
+ remote_.remote_addr, remote_.rkey, RDMA_WRITE_ID_MESSAGE, this);
+}
+
+// Generalized Write method
+void RdmaMessageBuffer::Write(const RdmaChannel* channel, uint32_t imm_data,
+ size_t buffer_size, uint64_t src_addr,
+ uint32_t lkey, uint64_t remote_addr,
+ uint32_t rkey, RdmaWriteIDType write_type,
+ void* write_context) {
struct ibv_sge list;
- list.addr = (uint64_t)buffer_;
+ list.addr = src_addr;
list.length = buffer_size;
- list.lkey = self_->lkey;
+ list.lkey = lkey;
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
- wr.wr_id = (uint64_t) this;
+ wr.wr_id = (uint64_t) new RdmaWriteID(write_type, write_context);
wr.sg_list = &list;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wr.send_flags = IBV_SEND_SIGNALED;
wr.imm_data = imm_data;
- wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr;
- wr.wr.rdma.rkey = remote_.rkey;
+ wr.wr.rdma.remote_addr = remote_addr;
+ wr.wr.rdma.rkey = rkey;
struct ibv_send_wr* bad_wr;
- CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send";
-}
-
-RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name)
- : RdmaBuffer(channel, name) {}
-
-RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
- : RdmaBuffer(channel, name) {}
-
-RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
- : RdmaBuffer(channel, name) {}
-
-RdmaTensorBuffer::~RdmaTensorBuffer() {
- for (Itable it = retable.begin(); it != retable.end(); ++it) {
- delete (it->second);
- }
+ CHECK(!ibv_post_send(channel->qp_, &wr, &bad_wr)) << "Failed to post send";
}
// Send the next ack from the buffer's job queue.
-void RdmaAckBuffer::SendNextItem() {
- uint32_t imm_data = LookupBufferIndex("rx_ack_buffer");
- RdmaMessage rm;
- rm.name_ = "rx_ack_buffer";
- rm.type_ = RDMA_MESSAGE_ACK;
- rm.name_size_ = rm.name_.size();
- string message = RdmaMessage::CreateMessage(rm);
- memcpy(buffer_, message.data(), message.size());
- Write(imm_data, message.size());
+void RdmaMessageBuffer::SendAck(const RdmaChannel* channel) {
+ Write(channel, RDMA_IMM_DATA_ACK, 0, 0, 0, 0, 0, RDMA_WRITE_ID_ACK, nullptr);
}
// Send the next message from the buffer's job queue.
void RdmaMessageBuffer::SendNextItem() {
- uint32_t imm_data = LookupBufferIndex("rx_message_buffer");
+ uint32_t imm_data = RDMA_IMM_DATA_MESSAGE;
mu_.lock();
if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
local_status_ = busy;
@@ -1029,244 +880,390 @@ void RdmaMessageBuffer::SendNextItem() {
}
}
-Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback(
- const string& key_with_step_id, const string& key, int64 step_id,
- const Rendezvous::ParsedKey& parsed) {
- Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id, parsed](
- const Status& status, const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
- CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
- << " error message: " << status.error_message();
- size_t buffer_size = RdmaMessage::kMessageTotalBytes;
- size_t tensor_bytes = 0;
- // Figures out which device the tensor is hosted on.
- Device* src_dev = nullptr;
- Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
- parsed.src_device, &src_dev);
- CHECK(s.ok()) << "src device not found";
- // Does the device have the right incarnation number we expect?
- CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
- << "RecvTensor expects a different device incarnation: "
- << parsed.src_incarnation << " vs. "
- << src_dev->attributes().incarnation()
- << ". Your worker job was probably restarted. Check your "
- << "worker job for the reason why it was restarted.";
- Device* dst_dev = nullptr;
- // destination is on CPU.
- s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
- &dst_dev);
- CHECK(s.ok()) << "dst device not found";
- AllocatorAttributes dst_alloc_attr;
- dst_alloc_attr.set_on_host(true);
-
- bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
- // string tensor needs to be serialized
- Tensor copy;
- TensorProto proto;
- if (src_dev->tensorflow_gpu_device_info() &&
- (!send_args.alloc_attrs.on_host())) {
#if GOOGLE_CUDA
- CHECK(send_args.device_context) << "send dev name: " << src_dev->name()
- << " gpu_info: "
- << src_dev->tensorflow_gpu_device_info();
-
- if (can_memcpy) {
- AllocatorAttributes host_alloc_attrs;
- host_alloc_attrs.set_gpu_compatible(true);
- host_alloc_attrs.set_on_host(true);
- Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
- copy = Tensor(alloc, in.dtype(), in.shape());
- tensor_bytes = in.TotalBytes();
- buffer_size += tensor_bytes;
- GPUUtil::CopyGPUTensorToCPU(
- src_dev, send_args.device_context, &in, &copy,
- [this, copy, tensor_bytes, buffer_size, key, in, step_id,
- key_with_step_id, is_dead, send_args, recv_args](const Status& s) {
- CHECK(s.ok()) << "copy tensor from gpu sync";
- StringPiece copy_buf;
- copy_buf = copy.tensor_data();
- PostCopyOperations(true, buffer_size, tensor_bytes, key, in,
- step_id, is_dead, key_with_step_id, &copy,
- NULL, &copy_buf, send_args, recv_args);
- });
- } else {
- // "val" is on a GPU. No longer uses GPUUtil to fill the proto, use
- // aync instead
- GPUUtil::SetProtoFromGPU(
- in, src_dev, send_args.device_context, &proto, is_dead,
- [this, proto, buffer_size, key, in, step_id, key_with_step_id,
- is_dead, send_args, recv_args](const Status& s) mutable {
- CHECK(s.ok()) << "copy proto from gpu sync";
- auto tensor_bytes = proto.ByteSize();
- buffer_size += tensor_bytes;
- PostCopyOperations(false, buffer_size, tensor_bytes, key, in,
- step_id, is_dead, key_with_step_id, NULL,
- &proto, NULL, send_args, recv_args);
- });
- }
+static void CountCopies(const std::string& key, void* src_addr, void* dst_addr,
+ size_t tensor_bytes, bool is_gpu_to_cpu) {
+#ifdef RDMA_COUNT_COPIES
+ static uint64_t numGPUToCPUCopies = 0;
+ static uint64_t numGPUToCPUCopiedBytes = 0;
+ static uint64_t numCPUToGPUCopies = 0;
+ static uint64_t numCPUToGPUCopiedBytes = 0;
+ static uint64_t numTotalCopies = 0;
+
+ if (is_gpu_to_cpu) {
+ ++numGPUToCPUCopies;
+ numGPUToCPUCopiedBytes += tensor_bytes;
+ } else {
+ ++numCPUToGPUCopies;
+ numCPUToGPUCopiedBytes += tensor_bytes;
+ }
+ if ((++numTotalCopies % 0x400) == 0) {
+ RDMA_LOG(0) << "Tensor copies:"
+ << " GPU to CPU: " << numGPUToCPUCopies << " ("
+ << numGPUToCPUCopiedBytes << " Bytes)"
+ << " CPU to GPU: " << numCPUToGPUCopies << " ("
+ << numCPUToGPUCopiedBytes << " Bytes)";
+ }
+ RDMA_LOG(2) << "Copying tensor " << key << " From: " << src_addr
+ << " To: " << dst_addr;
+#endif // RDMA_COUNT_COPIES
+}
#endif // GOOGLE_CUDA
- } else {
- // tensor is in CPU memory.
- StringPiece copy_buf;
- if (can_memcpy) {
- copy_buf = in.tensor_data();
- tensor_bytes = in.TotalBytes();
- } else {
- in.AsProtoTensorContent(&proto);
- tensor_bytes = proto.ByteSize();
- }
- buffer_size += tensor_bytes;
- PostCopyOperations(can_memcpy, buffer_size, tensor_bytes, key, in,
- step_id, is_dead, key_with_step_id, &copy, &proto,
- &copy_buf, send_args, recv_args);
+
+#ifdef RDMA_DATA_VALIDATION
+static uint64_t Checksum(Device* device, const DeviceContext* device_context,
+ const Tensor& in) {
+ uint64 checksum = 0;
+ if (DataTypeCanUseMemcpy(in.dtype())) {
+#if GOOGLE_CUDA
+ if (in.TotalBytes() == 0) {
+ return 0;
}
- };
- return cb;
+ checksum = (device_context != nullptr)
+ ? GPUUtil::Checksum(device, device_context, in)
+ : GPUUtil::Checksum(in);
+#endif // GOOGLE_CUDA
+ } else {
+ string s = in.SummarizeValue(999999);
+ checksum = Hash64(s.c_str(), s.size(), 0);
+ }
+ return checksum;
}
-// Send the next tensor from the buffer's job queue.
-void RdmaTensorBuffer::SendNextItem() {
- // get the key
- string key_with_step_id = "";
- {
- mutex_lock lock{mu_};
- if (!queue_.empty()) {
- key_with_step_id = queue_.front();
- queue_.pop();
+static void ValidateChecksum(uint64_t expected, uint64_t actual,
+ const Tensor& in, uint32_t request_index,
+ const std::string& key, const std::string& msg) {
+ RDMA_LOG(2) << "Request #" << request_index << ": " << key
+ << ": Checksum: " << std::hex << " Expected = 0x" << expected
+ << ". Actual = 0x" << actual << ".";
+
+ if (expected != actual) {
+ // Checksum failed. There is one case where this is allowed - if the
+ // tensor is an AssignAdd of the global step. Since the data-validation
+ // always postpones the Tensor response in order to send a checksum message,
+ // it is possible that the global-step was updated while the response was
+ // still in queue.
+ if ((in.TotalBytes() == 8) && (in.dtype() == DT_INT64)) {
+ int64_t prev_val = *(int64_t*)DMAHelper::base(&in) - 1;
+ actual = Hash64((const char*)&prev_val, 8, 0);
+ }
+ if (expected != actual) {
+ LOG(FATAL) << "[" << msg << "]: Checksum validation failed for request #"
+ << request_index << ": " << key << std::hex << " "
+ << DataTypeString(in.dtype()) << " "
+ << in.shape().DebugString() << " (0x" << in.TotalBytes()
+ << " bytes): "
+ << " Expected 0x" << expected << ". Got 0x" << actual << ".";
}
}
+}
+#endif // RDMA_DATA_VALIDATION
+
+#if GOOGLE_CUDA
+// Sync the 'done' operation on the GPU stream, but without all the data
+// copying.
+static void StreamGPUOp(Device* gpu_device, const DeviceContext* device_context,
+ StatusCallback done) {
+ Tensor dummy1, dummy2;
+ GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &dummy1, &dummy2,
+ done);
+}
+#endif // GOOGLE_CUDA
+
+RdmaTensorResponse* RdmaChannel::AddTensorResponse(const RdmaMessage& rm) {
+ mutex_lock lock{mu_};
+ auto it =
+ responses_table_.emplace(rm.request_index_, RdmaTensorResponse(this, rm));
+ CHECK(it.second) << "Response with the ID " << rm.request_index_
+ << " already exists.";
+ return &it.first->second;
+}
+
+RdmaTensorResponse* RdmaChannel::UpdateTensorResponse(const RdmaMessage& rm) {
+ mutex_lock lock{mu_};
+ auto it = responses_table_.find(rm.request_index_);
+ CHECK(it != responses_table_.end()) << "No response found.";
+ RdmaTensorResponse* response = &it->second;
+ response->Update(rm);
+ return response;
+}
+
+void RdmaChannel::RemoveTensorResponse(uint32_t request_index) {
+ mutex_lock lock{mu_};
+ responses_table_.erase(request_index);
+}
+
+void RdmaTensorResponse::Start() {
+ Rendezvous::ParsedKey parsed;
+ Status s = Rendezvous::ParseKey(rm_.name_, &parsed);
+ if (!s.ok()) {
+ SendErrorStatus(s);
+ return;
+ }
- // send the tensor if a key is acquired.
- if (key_with_step_id != "") {
- VLOG(2) << "try to send tensor: " << key_with_step_id;
- string key;
- int64 step_id;
- VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
- CHECK(key.compare(name_) == 0);
- Rendezvous::ParsedKey parsed;
- Rendezvous::ParseKey(key, &parsed);
- Rendezvous::DoneCallback cb =
- getRecvTensorCallback(key_with_step_id, key, step_id, parsed);
- channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(step_id,
- parsed, cb);
+ channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(
+ rm_.step_id_, parsed,
+ [this, parsed](const Status& status, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in,
+ bool is_dead) {
+ CHECK(status.ok()) << "RecvLocalAsync was not ok."
+ << " error message: " << status.error_message();
+ RecvHandler(parsed, send_args, recv_args, in, is_dead);
+ });
+}
+
+void RdmaTensorResponse::Resume() { SendContent(*tensor_, *proto_, is_dead_); }
+
+// Helper for RecvTensor. Validates "key" and returns the source
+// device in "*src_dev".
+Status RdmaTensorResponse::PrepareRecvTensor(
+ const Rendezvous::ParsedKey& parsed, Device** src_dev) {
+ // Figures out which device the tensor is hosted on.
+ string local_name = DeviceNameUtils::LocalName(parsed.src_device);
+ TF_RETURN_IF_ERROR(channel_->adapter_->worker_env_->device_mgr->LookupDevice(
+ local_name, src_dev));
+
+ // Does the device have the right incarnation number we expect?
+ if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
+ return errors::Aborted(
+ "RecvTensor expects a different device incarnation: ",
+ parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
+ ". Your worker job was probably restarted. Check your "
+ "worker job for the reason why it was restarted.");
}
+
+ return Status::OK();
}
-void RdmaTensorBuffer::ReSendNextItem() {
- // get the key
- string key_with_step_id = "";
- {
- mutex_lock lock{mu_};
- if (!requeue.empty()) {
- key_with_step_id = requeue.front();
- requeue.pop();
- }
+void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& in, bool is_dead) {
+ Status s = PrepareRecvTensor(parsed, &src_dev_);
+ if (!s.ok()) {
+ SendErrorStatus(s);
+ return;
}
- // send the tensor if a key is acquired.
- if (key_with_step_id != "") {
- VLOG(2) << "try to send tensor: " << key_with_step_id;
- string key;
- int64 step_id;
- VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
- CHECK(key.compare(name_) == 0);
- Rendezvous::ParsedKey parsed;
- Rendezvous::ParseKey(key, &parsed);
- Rendezvous::DoneCallback cb =
- getRecvTensorCallback(key_with_step_id, key, step_id, parsed);
- ReItem* item;
- {
- mutex_lock lock{mu_};
- Itable it = retable.find(key_with_step_id);
- CHECK(it != retable.end()) << "Could not find dup-recv context";
- item = it->second;
- retable.erase(it);
+ meta_data_changed_ = TensorMetaDataChanged(in, is_dead);
+#ifdef RDMA_DATA_VALIDATION
+ // Always send a meta data message with the source checksum
+ meta_data_changed_ = rm_.type_ == RDMA_MESSAGE_TENSOR_REQUEST;
+ checksum_ = Checksum(src_dev_, send_args.device_context, in);
+#endif
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ // string tensor needs to be serialized
+ Tensor copy;
+ TensorProto proto;
+ const bool on_host = send_args.alloc_attrs.on_host();
+ if (src_dev_->tensorflow_gpu_device_info() && !on_host) {
+#if GOOGLE_CUDA
+ DeviceContext* send_dev_context = send_args.device_context;
+ CHECK(send_dev_context)
+ << "send dev name: " << src_dev_->name()
+ << " gpu_info: " << src_dev_->tensorflow_gpu_device_info();
+
+ if (can_memcpy) {
+ // If the tensor is located on a GDR compatible GPU, there is no need to
+ // copy it. We can send directly from the source, just need to make sure
+ // we are in sync with the GPU stream.
+ // If the tensor's meta-data changed however, we will need to clone it,
+ // so anyway we'll have to copy it from GPU to CPU first. If at some
+ // point in time Clone() is changed to only save a shallow copy, we can
+ // skip the copy here as well.
+ if ((in.TotalBytes() > 0) && !meta_data_changed_ &&
+ (RdmaMemoryMgr::Singleton().FindMemoryRegion(
+ (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
+ StreamGPUOp(src_dev_, send_dev_context,
+ [this, in, proto, is_dead](const Status& s) {
+ Send(in, proto, is_dead, s);
+ });
+ return;
+ }
+
+ // The tensor must be copied from GPU to CPU, because either:
+ // 1. The tensor is located on a non GDR compatible GPU.
+ // 2. The tensor's meta-data has changed.
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ copy = Tensor(alloc, in.dtype(), in.shape());
+ CountCopies(rm_.name_, (void*)DMAHelper::base(&in),
+ (void*)DMAHelper::base(&copy), in.TotalBytes(), true);
+ GPUUtil::CopyGPUTensorToCPU(
+ src_dev_, send_dev_context, &in, &copy,
+ [this, copy, proto, is_dead](const Status& s) {
+ Send(copy, proto, is_dead, s);
+ });
+ } else {
+ GPUUtil::SetProtoFromGPU(
+ in, src_dev_, send_args.device_context, &proto, is_dead,
+ [this, in, proto, is_dead](const Status& s) mutable {
+ Send(in, proto, is_dead, s);
+ });
+ }
+#else
+ SendErrorStatus(errors::Internal("No GPU device in process"));
+#endif // GOOGLE_CUDA
+ } else {
+ // tensor is in CPU memory.
+ if (!can_memcpy) {
+ in.AsProtoTensorContent(&proto);
}
- cb(Status::OK(), item->send_args, item->recv_args, item->in, item->is_dead);
- delete (item);
+ Send(in, proto, is_dead, Status::OK());
+ }
+}
+
+void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto,
+ bool is_dead, const Status& status) {
+ if (!status.ok()) {
+ SendErrorStatus(status);
+ return;
+ }
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ bool proto_size_changed =
+ (!can_memcpy) && (proto.ByteSize() != rm_.tensor_bytes_);
+ if (meta_data_changed_ || proto_size_changed) {
+ Clone(in, proto, is_dead);
+ SendMetaData(in, proto, is_dead);
+ } else {
+ SendContent(in, proto, is_dead);
}
}
-void RdmaTensorBuffer::PostCopyOperations(
- bool can_memcpy, size_t buffer_size, size_t tensor_bytes, const string& key,
- const Tensor& in, int64 step_id, bool is_dead,
- const string& key_with_step_id, const Tensor* copy,
- const TensorProto* proto, const StringPiece* copy_buf,
- const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args) {
- // prepare message
+bool RdmaTensorResponse::TensorMetaDataChanged(const Tensor& in, bool is_dead) {
+ return (rm_.data_type_ != in.dtype()) || (rm_.tensor_shape_ != in.shape()) ||
+ (rm_.is_dead_ != is_dead);
+}
+
+void RdmaTensorResponse::Clone(const Tensor& in, const TensorProto& proto,
+ bool is_dead) {
+ // Clone the data to be sent later. For simplicity, we clone the tensor's
+ // data even if it is already a copy. Performance is less of a concern here
+ // since the meta-data hardly ever changes. The reason we create a copy, is
+ // that some tensors share their buffer between different step-ids, so the
+ // tensor content may change before re-request was completed.
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ if (can_memcpy && (in.TotalBytes() > 0)) {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_nic_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ Allocator* allocator = src_dev_->GetAllocator(host_alloc_attrs);
+ tensor_ = new Tensor(allocator, in.dtype(), in.shape());
+ memcpy(DMAHelper::base(tensor_), DMAHelper::base(&in), in.TotalBytes());
+ } else {
+ tensor_ = new Tensor(in.dtype(), in.shape());
+ }
+ if (!can_memcpy) {
+ proto_ = new TensorProto(proto);
+ }
+ is_dead_ = is_dead;
+}
+
+void RdmaTensorResponse::SendMetaData(const Tensor& in,
+ const TensorProto& proto, bool is_dead) {
+ RDMA_LOG(2) << "Request #" << rm_.request_index_
+ << ": Meta data changed: " << rm_.name_;
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
+
+ // Send meta-data update:
RdmaMessage rm;
- rm.name_size_ = key.size();
- rm.name_ = key;
+ rm.type_ = RDMA_MESSAGE_META_DATA_UPDATE;
+ rm.name_size_ = rm_.name_.size();
+ rm.name_ = rm_.name_;
rm.tensor_shape_ = in.shape();
rm.data_type_ = in.dtype();
- rm.step_id_ = step_id;
+ rm.step_id_ = rm_.step_id_;
rm.is_dead_ = is_dead;
rm.tensor_bytes_ = tensor_bytes;
- rm.buffer_size_ = buffer_size;
- mu_.lock();
- if (local_status_ == none || (buffer_size > size_ && local_status_ == idle &&
- remote_status_ == idle)) {
- if ((local_status_ != none) && (buffer_size > size_)) {
- VLOG(2) << "Extend RDMA buffer from " << size_ << " to " << buffer_size;
- }
- CreateCPUBuffer(buffer_size, false);
- // Need to be received again, put into the re-recv queue and the table
- requeue.push(key_with_step_id);
- ReItem* item = new ReItem(send_args, recv_args, in, is_dead);
- retable.insert(std::pair<string, ReItem*>(key_with_step_id, item));
- mu_.unlock();
- // no longer used: put back the key since it is not sent;
- // ask the remote to create the same buffer
- rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
- rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
- rm.rkey_ = self_->rkey;
- string message = RdmaMessage::CreateMessage(rm);
- channel_->tx_message_buffer_->EnqueueItem(message);
- channel_->tx_message_buffer_->SendNextItem();
- } else if ((local_status_ == idle) && (remote_status_ == idle)) {
- // both buffers are ready, send the tensor
- local_status_ = busy;
- remote_status_ = busy;
- // local/remote_status_ won't be set back to idle
- // unitl Write() is successful
- mu_.unlock();
- if (!((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
- (buffer_size <= size_ && rm.data_type_ == DT_STRING))) {
- VLOG(2) << "Tensor and buffer size do not agree,"
- << " buffer_size = " << size_
- << " requested tensor size = " << buffer_size << in.DebugString();
- }
- uint32_t imm_data = LookupBufferIndex(key);
- rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
- string message = RdmaMessage::CreateMessage(rm);
- memcpy(buffer_, message.data(), message.size());
- if (!is_dead) {
- // copy the tensor buffer content
- void* output = static_cast<void*>(static_cast<char*>(buffer_) +
- RdmaMessage::kTensorBufferStartIndex);
- CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
- if (can_memcpy) {
- CHECK(copy != NULL) << "callback missing pointer to copy tensor";
- CHECK(copy_buf != NULL) << "callback missing pointer to copy buffer";
- CHECK(copy_buf->size() == tensor_bytes)
- << "unexpected tensor size: " << copy_buf->size()
- << " != " << tensor_bytes;
- memcpy(output, copy_buf->data(), tensor_bytes);
- } else {
- CHECK(proto != NULL) << "callback missing pointer to proto tensor";
- proto->SerializeToArray(output, tensor_bytes);
+ rm.request_index_ = rm_.request_index_;
+#ifdef RDMA_DATA_VALIDATION
+ rm.checksum_ = checksum_;
+#endif
+ RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
+ << ": Sending RDMA_MESSAGE_META_DATA_UPDATE #"
+ << rm.request_index_ << ": " << rm.name_
+ << " (shape = " << rm.tensor_shape_.DebugString() << "."
+ << " data-type = " << DataTypeString(rm.data_type_) << "."
+ << " is-dead = " << rm.is_dead_ << ")";
+
+ string message = RdmaMessage::CreateMessage(rm);
+ channel_->tx_message_buffer_->EnqueueItem(message);
+ channel_->tx_message_buffer_->SendNextItem();
+}
+
+void RdmaTensorResponse::SendContent(const Tensor& in, const TensorProto& proto,
+ bool is_dead) {
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
+ size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
+ uint32_t imm_data = rm_.request_index_;
+ if (!is_dead) {
+ if (can_memcpy) {
+ src_buffer_ = const_cast<TensorBuffer*>(DMAHelper::buffer(&in));
+ if (src_buffer_ != nullptr) {
+ src_buffer_->Ref(); // Keep buffer alive until write is complete
+ src_addr_ = src_buffer_->data();
+ mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(src_addr_,
+ tensor_bytes);
}
} else {
- buffer_size = RdmaMessage::kMessageTotalBytes;
+ RDMA_LOG(2) << "Encoding proto: " << rm_.name_
+ << " (Size: " << tensor_bytes << ") " << in.DebugString();
+ src_addr_ = malloc(tensor_bytes);
+ mr_ = ibv_reg_mr(channel_->adapter_->pd_, src_addr_, tensor_bytes,
+ IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+ proto.SerializeToArray(src_addr_, tensor_bytes);
}
- Write(imm_data, buffer_size);
} else {
- // Need to be received again, put into the re-recv queue and the table
- requeue.push(key_with_step_id);
- ReItem* item = new ReItem(send_args, recv_args, in, is_dead);
- retable.insert(std::pair<string, ReItem*>(key_with_step_id, item));
- mu_.unlock();
+ tensor_bytes = 0;
+ }
+
+ uint32_t lkey = (mr_ == nullptr) ? 0 : mr_->lkey;
+ RDMA_LOG(1) << "Step 0x" << std::hex << rm_.step_id_ << std::dec
+ << ": Sending tensor content #" << rm_.request_index_ << " from "
+ << std::hex << src_addr_ << " (0x" << lkey << ")"
+ << " to " << rm_.remote_addr_ << " (0x" << rm_.rkey_
+ << "): " << rm_.name_ << " (size: 0x" << std::hex << tensor_bytes
+ << ")";
+
+ RdmaMessageBuffer::Write(channel_, imm_data, tensor_bytes,
+ (uint64_t)src_addr_, lkey, rm_.remote_addr_,
+ rm_.rkey_, RDMA_WRITE_ID_TENSOR_WRITE, this);
+}
+
+void RdmaTensorResponse::SendErrorStatus(const Status& status) {
+ RdmaMessage rm;
+ rm.type_ = RDMA_MESSAGE_ERROR_STATUS;
+ rm.name_size_ = rm_.name_.size();
+ rm.name_ = rm_.name_;
+ rm.step_id_ = rm_.step_id_;
+ rm.request_index_ = rm_.request_index_;
+ rm.status_ = status;
+ LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec
+ << ": Sending RDMA_MESSAGE_ERROR_STATUS #" << rm.request_index_
+ << ": " << rm.name_ << ". Status: " << status.ToString();
+
+ string message = RdmaMessage::CreateMessage(rm);
+ channel_->tx_message_buffer_->EnqueueItem(message);
+ channel_->tx_message_buffer_->SendNextItem();
+
+ // Destroy the response.
+ Destroy();
+}
+
+void RdmaTensorResponse::Destroy() {
+ if (src_buffer_ != nullptr) {
+ src_buffer_->Unref();
+ }
+ if (tensor_ != nullptr) {
+ delete tensor_;
+ }
+ if (proto_ != nullptr) {
+ ibv_dereg_mr(mr_);
+ free(src_addr_);
+ delete proto_;
}
+ // Remove response from the pending list:
+ channel_->RemoveTensorResponse(rm_.request_index_);
}
// Create a RdmaMessage according to the pre-defined format
@@ -1276,43 +1273,46 @@ void RdmaTensorBuffer::PostCopyOperations(
// message in string format
string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
// Rdma Message format
- // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
- // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
- // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
- // ...| XB | XB | 8B |...
+ // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|...
+ // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
+ // ...|data_type|tensor_shape|tensor_bytes|error_status |
+ // ...| XB | XB | 8B |size - 4B, proto - XB |
//
- // ACK: type|13|"rx_ack_buffer"
- // TENSOR_REQUEST: type|name_size|tensor_name|step_id
- // TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead
- // |data_type|tensor_shape|tensor_bytes
- // BUFFER_IDLE: type|name_size|buffer_name
- // BUFFER_REQUEST:
- // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
- // BUFFER_RESPONSE:
- // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
- char message[kMessageTotalBytes];
+ // ACK: Imm-type: ACK
+ // TENSOR_REQUEST: Imm-type: MESSAGE
+ // Fields: type, request_index, name, step_id, remote_addr,
+ // rkey, is_dead, data_type, tensor_shape, tensor_bytes
+ // META_DATA_UPDATE: Imm-type: MESSAGE
+ // Fields: type, request_index, is_dead, data_type,
+ // tensor_shape, tensor_bytes
+ // TENSOR_RE_REQUST: Imm-type: MESSAGE
+ // Fields: type, request_index, name, step_id, remote_addr,
+ // rkey, is_dead, data_type, tensor_shape, tensor_bytes
+ // ERROR_STATUS: Imm-type: MESSAGE
+ // Fields: type, request_index, name, step_id, error_status
+ // Tensor content: Imm-type: request_index
+ size_t message_size = kMessageTotalBytes;
+ char message[kMessageTotalBytes + kErrorStatusMaxSize];
// type
message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
- // size of name
- memcpy(&message[kNameSizeStartIndex], &rm.name_size_, sizeof(rm.name_size_));
- // name
- memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
- // buffer_size, remote_addr, rkey
- if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
- (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
- memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
- sizeof(rm.buffer_size_));
+ // request index
+ memcpy(&message[kRequestIndexStartIndex], &rm.request_index_,
+ sizeof(rm.request_index_));
+ // name, step_id, remote_addr, rkey
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
+ memcpy(&message[kNameSizeStartIndex], &rm.name_size_,
+ sizeof(rm.name_size_));
+ memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
sizeof(rm.remote_addr_));
memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
- }
- // step_id
- if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
- (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
}
// is_dead, data_type, tensor_shape, tensor_bytes
- if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
@@ -1322,7 +1322,30 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
sizeof(rm.tensor_bytes_));
}
- return string(message, kMessageTotalBytes);
+ // checksum
+#ifdef RDMA_DATA_VALIDATION
+ memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_));
+#endif
+ // error status
+ if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
+ ::grpc::Status gs = ToGrpcStatus(rm.status_);
+ ErrorStatusProto gsProto;
+ gsProto.set_error_code(gs.error_code());
+ gsProto.set_error_message(gs.error_message());
+ gsProto.set_error_details(gs.error_details());
+ uint32_t gsProtoSize = gsProto.ByteSize();
+ if (gsProtoSize + 4 > kErrorStatusMaxSize) {
+ LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
+ << "is too big to fit in RDMA message (" << kErrorStatusMaxSize
+ << " bytes). Truncated.";
+ gsProtoSize = kErrorStatusMaxSize - 4;
+ }
+ uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex];
+ *proto_size = gsProtoSize;
+ gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], gsProtoSize);
+ message_size += gsProtoSize + 4;
+ }
+ return string(message, message_size);
}
// Parse a RdmaMessage according to the pre-defined format
@@ -1335,26 +1358,24 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
char* message = static_cast<char*>(buffer);
// type
rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
- // name_size_
- memcpy(&rm.name_size_, &message[kNameSizeStartIndex], sizeof(rm.name_size_));
- // name
- rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
- // buffer_size, remote_addr, rkey
- if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
- (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
- memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
- sizeof(rm.buffer_size_));
+ // request index
+ memcpy(&rm.request_index_, &message[kRequestIndexStartIndex],
+ sizeof(rm.request_index_));
+ // name, step_id, remote_addr, rkey
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
+ memcpy(&rm.name_size_, &message[kNameSizeStartIndex],
+ sizeof(rm.name_size_));
+ rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
sizeof(rm.remote_addr_));
memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
- }
- // step_id
- if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
- (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
}
// data_type, tensor_bytes, tensor_shape, is_dead
- if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+ if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
+ (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
+ (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
sizeof(rm.data_type_));
@@ -1363,6 +1384,291 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
sizeof(rm.tensor_bytes_));
}
+ // checksum
+#ifdef RDMA_DATA_VALIDATION
+ memcpy(&rm.checksum_, &message[kChecksumStartIndex], sizeof(rm.checksum_));
+#endif
+ // error status
+ if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
+ ErrorStatusProto gsProto;
+ uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex];
+ CHECK(ParseProtoUnlimited(&gsProto, &message[kErrorStatusStartIndex + 4],
+ gsProtoSize))
+ << "Failed to parse error status proto from message. Aborting.";
+ ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(),
+ gsProto.error_message(), gsProto.error_details());
+ rm.status_ = FromGrpcStatus(gs);
+ }
+}
+
+//*****************************************************************************
+// RdmaMemoryMgr
+//*****************************************************************************
+
+ibv_mr* RdmaMemoryMgr::FindMemoryRegion(void* addr, size_t length) {
+ mutex_lock l(mrs_mu_);
+ auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
+ if (iter == std::end(mrs_) || iter->get()->addr > addr) {
+ return nullptr;
+ } else {
+ return iter->get();
+ }
+}
+
+void RdmaMemoryMgr::InsertMemoryRegion(void* addr, size_t length,
+ const std::string& allocator_name) {
+ if (length == 0) return;
+ ibv_mr* mr = ibv_reg_mr(pd_, addr, length,
+ IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+ RDMA_LOG(1) << "Insert memory region 0x" << std::hex << mr->rkey << ". ["
+ << addr << "-" << (void*)((uint64_t)addr + length - 1) << "]"
+ << " SIZE: 0x" << length << " (" << allocator_name << ").";
+ if (mr != nullptr) {
+ mutex_lock l(mrs_mu_);
+ auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
+ mrs_.insert(iter, {mr, &MRDeleter});
+ } else {
+ LOG(WARNING) << "Cannot register memory region";
+ }
+}
+
+void RdmaMemoryMgr::EvictMemoryRegion(void* addr, size_t length) {
+ if (length == 0) return;
+ mutex_lock l(mrs_mu_);
+ auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
+ if (iter != std::end(mrs_) && iter->get()->addr == addr) {
+ mrs_.erase(iter);
+ RDMA_LOG(1) << "Evict memory region 0x" << std::hex << iter->get()->rkey;
+
+ } else {
+ LOG(WARNING) << "Failed to de-register memory region";
+ }
+}
+
+const TensorMetaData* RdmaMemoryMgr::GetTensorMetaData(
+ const std::string& tensor_name) {
+ mutex_lock l(tensor_meta_data_mu_);
+ auto it = tensors_meta_data_.find(tensor_name);
+ if (it == tensors_meta_data_.end()) {
+ return nullptr;
+ }
+ return &it->second;
+}
+
+const TensorMetaData* RdmaMemoryMgr::SetTensorMetaData(
+ const std::string& tensor_name, DataType dtype, const TensorShape& shape,
+ bool is_dead, size_t proto_size) {
+ mutex_lock l(tensor_meta_data_mu_);
+ TensorMetaData& meta_data = tensors_meta_data_[tensor_name];
+ meta_data.data_type_ = dtype;
+ meta_data.tensor_shape_ = shape;
+ meta_data.proto_size_ = proto_size;
+ meta_data.is_dead_ = is_dead;
+ return &meta_data;
+}
+
+//*****************************************************************************
+// RdmaTensorRequest
+//*****************************************************************************
+
+RdmaTensorRequest::RdmaTensorRequest(
+ uint32_t index, const string& key, int64 step_id, RdmaChannel* channel,
+ Device* dst_dev, const Rendezvous::Args recv_args,
+ const RdmaTensorRequest::RecvDoneCallback& done)
+ : index_(index),
+ key_(key),
+ step_id_(step_id),
+ channel_(channel),
+ dst_dev_(dst_dev),
+ recv_args_(recv_args),
+ meta_data_(RdmaMemoryMgr::Singleton().GetTensorMetaData(key)),
+ result_tensor_(nullptr),
+ proxy_tensor_(nullptr),
+ rdma_addr_(nullptr),
+ mr_(nullptr),
+ done_(done) {}
+
+RdmaTensorRequest::~RdmaTensorRequest() { DeallocateTensors(); }
+
+void RdmaTensorRequest::Done(const Status& s) {
+ Tensor val = std::move(*result_tensor_);
+
+#ifdef RDMA_DATA_VALIDATION
+ // Validate checksum
+ // Unfortunately we can't always do a Checksum directly on the result tensor.
+ // If the result tensor is on GPU, then we need to copy it back to CPU. If
+ // we happen to be in the midst of a proxy callback, then the copying will
+ // get stuck.
+ uint64_t checksum = (proxy_tensor_ != nullptr)
+ ? Checksum(nullptr, nullptr, *proxy_tensor_)
+ : Checksum(dst_dev_, recv_args_.device_context, val);
+ ValidateChecksum(checksum_, checksum, val, index_, key_, "RDMA");
+#endif
+
+ Rendezvous::Args recv_args = std::move(recv_args_);
+ bool is_dead = (meta_data_ == nullptr) ? false : meta_data_->is_dead_;
+ RecvDoneCallback done = done_;
+ DeallocateTensors();
+ channel_->RemoveTensorRequest(index_);
+ done(s, Rendezvous::Args(), recv_args, val, is_dead);
+}
+
+void RdmaTensorRequest::DeallocateTensors() {
+ if (result_tensor_ != nullptr) {
+ delete result_tensor_;
+ result_tensor_ = nullptr;
+ }
+ if (proxy_tensor_ != nullptr) {
+ delete proxy_tensor_;
+ proxy_tensor_ = nullptr;
+ }
+}
+
+bool RdmaTensorRequest::AllocateTensors() {
+ result_tensor_ =
+ new Tensor(dst_dev_->GetAllocator(recv_args_.alloc_attrs),
+ meta_data_->data_type_, meta_data_->tensor_shape_);
+
+ size_t tensor_size = result_tensor_->TotalBytes();
+ bool can_memcpy = DataTypeCanUseMemcpy(result_tensor_->dtype());
+ if (can_memcpy) {
+ if (tensor_size == 0) {
+ return true;
+ }
+ rdma_addr_ = DMAHelper::base(result_tensor_);
+ mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
+#if GOOGLE_CUDA
+ if (mr_ == nullptr) {
+ // Can't RDMA directly to result. Use a proxy.
+ proxy_tensor_ =
+ new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0),
+ result_tensor_->dtype(), result_tensor_->shape());
+ rdma_addr_ = DMAHelper::base(proxy_tensor_);
+ mr_ =
+ RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
+ }
+#endif
+ } else {
+ uint32_t proto_size = meta_data_->proto_size_;
+ rdma_addr_ = malloc(proto_size);
+ mr_ = ibv_reg_mr(RdmaMemoryMgr::Singleton().pd_, rdma_addr_, proto_size,
+ IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+ }
+ CHECK(mr_ != nullptr) << " No memory region found for address " << rdma_addr_
+ << ": " << key_;
+ return true;
+}
+
+void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) {
+ AllocateTensors();
+ bool on_host = recv_args_.alloc_attrs.on_host();
+ if (dst_dev_->tensorflow_gpu_device_info() && !on_host &&
+ (proxy_tensor_ == nullptr)) {
+#if GOOGLE_CUDA
+ // We need to sync the memory allocation on the GPU:
+ StreamGPUOp(dst_dev_, recv_args_.device_context, done);
+#endif
+ } else {
+ done(Status::OK());
+ }
+}
+
+void RdmaTensorRequest::Send(RdmaMessageType message_type) {
+ RdmaMessageBuffer* rb = channel_->tx_message_buffer_;
+ RdmaMessage rm;
+ rm.type_ = message_type;
+ rm.request_index_ = index_;
+ rm.name_size_ = key_.size();
+ rm.name_ = key_;
+ rm.step_id_ = step_id_;
+ rm.remote_addr_ = (uint64_t)rdma_addr_;
+ if (meta_data_ != nullptr) {
+ rm.data_type_ = meta_data_->data_type_;
+ rm.tensor_shape_ = meta_data_->tensor_shape_;
+ rm.is_dead_ = meta_data_->is_dead_;
+ rm.tensor_bytes_ = meta_data_->proto_size_;
+ } else {
+ rm.data_type_ = DT_INVALID;
+ }
+ rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey;
+
+ RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
+ << ": Sending " << MessageTypeToString(message_type) << " #"
+ << index_ << ": " << rm.name_ << " on " << rdma_addr_
+ << " (rkey: 0x" << std::hex << rm.rkey_ << ")";
+
+ string message = RdmaMessage::CreateMessage(rm);
+ rb->EnqueueItem(message);
+ rb->SendNextItem();
+}
+
+void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape,
+ bool is_dead, size_t proto_size) {
+ meta_data_ = RdmaMemoryMgr::Singleton().SetTensorMetaData(
+ key_, dtype, shape, is_dead, proto_size);
+
+ DeallocateTensors();
+ AllocateTensorsAsync(
+ [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); });
+}
+
+void RdmaTensorRequest::RecvTensorContent() {
+ bool can_memcpy = DataTypeCanUseMemcpy(meta_data_->data_type_);
+ size_t message_size =
+ can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_;
+ RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec
+ << ": Received tensor content #" << index_ << ": " << key_
+ << " (Size: 0x" << std::hex << message_size << ")";
+
+ Tensor val;
+
+#if GOOGLE_CUDA
+ if (proxy_tensor_ != nullptr) {
+ CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_),
+ (void*)DMAHelper::base(result_tensor_),
+ result_tensor_->TotalBytes(), false);
+ GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context,
+ dst_dev_, result_tensor_,
+ [this](const Status& s) {
+ CHECK(s.ok()) << "copy tensor to gpu sync";
+ Done(s);
+ });
+ return;
+ }
+#endif
+
+ if (can_memcpy) {
+ Done(Status::OK());
+ } else {
+ RDMA_LOG(2) << "Decoding proto: " << key_
+ << " (Size: " << meta_data_->proto_size_ << ")";
+ TensorProto proto;
+ CHECK(ParseProtoUnlimited(&proto, rdma_addr_, meta_data_->proto_size_))
+ << "fail to parse proto from array";
+ ibv_dereg_mr(mr_);
+ free(rdma_addr_);
+ Status s = dst_dev_->MakeTensorFromProto(proto, recv_args_.alloc_attrs,
+ result_tensor_);
+ Done(s);
+ }
+}
+
+void RdmaTensorRequest::RecvErrorStatus(const Status& status) {
+ if (result_tensor_ == nullptr) {
+ result_tensor_ = new Tensor();
+ }
+ LOG(ERROR) << "Received RDMA_MESSAGE_ERROR_STATUS: " << status.ToString();
+ Done(status);
+}
+
+void RdmaTensorRequest::Start() {
+ meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_);
+ if (meta_data_ != nullptr) {
+ AllocateTensorsAsync(
+ [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_REQUEST); });
+ } else {
+ Send(RDMA_MESSAGE_TENSOR_REQUEST);
+ }
}
} // end namespace tensorflow
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
index fea2327d77..94203ee2b3 100644
--- a/tensorflow/contrib/verbs/rdma.h
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+#define TENSORFLOW_CONTRIB_VERBS_RDMA_H_
#ifdef TENSORFLOW_USE_VERBS
@@ -27,6 +27,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
@@ -43,6 +44,11 @@ namespace tensorflow {
#define SL_DEFAULT 0
#define TRAFFIC_CLASS 0
+#define RDMA_LOG_0 LOG(INFO)
+#define RDMA_LOG_1 VLOG(1)
+#define RDMA_LOG_2 VLOG(2)
+#define RDMA_LOG(LEVEL) RDMA_LOG_##LEVEL
+
struct RdmaParams {
uint8_t port_num;
uint8_t sgid_index;
@@ -67,38 +73,305 @@ struct RemoteMR {
uint64_t remote_addr;
uint32_t rkey;
};
-enum BufferStatus {
- none,
- idle,
- busy
+enum BufferStatus { none, idle, busy };
+enum Location { local, remote };
+
+enum RdmaMessageType {
+ RDMA_MESSAGE_META_DATA_UPDATE,
+ RDMA_MESSAGE_TENSOR_RE_REQUEST,
+ RDMA_MESSAGE_TENSOR_REQUEST,
+ RDMA_MESSAGE_ERROR_STATUS,
+};
+
+struct RdmaMessage {
+ RdmaMessageType type_;
+ uint16_t name_size_;
+ string name_;
+ int64 step_id_;
+ uint64_t request_index_;
+ union {
+ uint64_t remote_addr_;
+#ifdef RDMA_DATA_VALIDATION
+ uint64_t checksum_;
+#endif
+ };
+ uint32_t rkey_;
+ bool is_dead_;
+ DataType data_type_;
+ TensorShape tensor_shape_;
+ size_t tensor_bytes_;
+
+ // For error status:
+ Status status_;
+
+ // type|name_size|name|step_id|request_index|remote_addr/checksum|rkey|...
+ // 1B| 2B | 512| 8B | 8B | 8B | 4B |...
+ // ...|is_dead|data_type|tensor_shape|tensor_bytes|error_status |
+ // ...| 1B | XB | XB | 8B |size - 4B, proto - XB |
+ static const size_t kNameCapacity = 512;
+ static const size_t kTypeStartIndex = 0;
+ static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
+ static const size_t kNameStartIndex =
+ kNameSizeStartIndex + sizeof(name_size_);
+ static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
+ static const size_t kRequestIndexStartIndex =
+ kStepIdStartIndex + sizeof(step_id_);
+ static const size_t kRemoteAddrStartIndex =
+ kRequestIndexStartIndex + sizeof(request_index_);
+ static const size_t kChecksumStartIndex = kRemoteAddrStartIndex;
+ static const size_t kRkeyStartIndex =
+ kRemoteAddrStartIndex + sizeof(remote_addr_);
+ static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
+ static const size_t kDataTypeStartIndex =
+ kIsDeadStartIndex + sizeof(is_dead_);
+ static const size_t kTensorShapeStartIndex =
+ kDataTypeStartIndex + sizeof(data_type_);
+ static const size_t kTensorBytesStartIndex =
+ kTensorShapeStartIndex + sizeof(TensorShape);
+ static const size_t kErrorStatusStartIndex =
+ kTensorBytesStartIndex + sizeof(tensor_bytes_);
+ static const size_t kErrorStatusMaxSize = 4096;
+
+ static const size_t kMessageTotalBytes = kErrorStatusStartIndex;
+ static const size_t kRdmaMessageBufferSize =
+ kMessageTotalBytes + kErrorStatusMaxSize;
+ static string CreateMessage(const RdmaMessage& rm);
+ static void ParseMessage(RdmaMessage& rm, void* buffer);
+};
+
+// Immediate types for RDMA write
+enum RdmaImmDataType {
+ RDMA_IMM_MAX_REQUEST_ID = 0xFFFFFFFD,
+ RDMA_IMM_DATA_ACK = 0xFFFFFFFE,
+ RDMA_IMM_DATA_MESSAGE = 0xFFFFFFFF
};
-enum Location {
- local,
- remote
+
+// Write types for RDMA write-complete events
+enum RdmaWriteIDType {
+ RDMA_WRITE_ID_ACK,
+ RDMA_WRITE_ID_MESSAGE,
+ RDMA_WRITE_ID_TENSOR_WRITE
};
-enum BufferType {
- ACK,
- MESSAGE,
- TENSOR
+
+// Context for RDMA write-complete events
+class RdmaWriteID {
+ public:
+ RdmaWriteID(RdmaWriteIDType write_type, void* write_context)
+ : write_type(write_type), write_context(write_context) {}
+
+ RdmaWriteIDType write_type;
+ void* write_context;
};
-enum RdmaMessageType {
- RDMA_MESSAGE_ACK,
- RDMA_MESSAGE_BUFFER_IDLE,
- RDMA_MESSAGE_BUFFER_REQUEST,
- RDMA_MESSAGE_BUFFER_RESPONSE,
- RDMA_MESSAGE_TENSOR_REQUEST,
- RDMA_MESSAGE_TENSOR_WRITE
+
+// Tensor meta-data
+class TensorMetaData {
+ public:
+ TensorShape tensor_shape_;
+ DataType data_type_;
+ size_t proto_size_;
+ bool is_dead_;
+
+ std::ostream& print(std::ostream& out) const {
+ out << "Dtype = " << DataTypeString(data_type_)
+ << ", Shape = " << tensor_shape_.DebugString() << ", Proto size = 0x"
+ << std::hex << proto_size_ << ", Is dead = " << is_dead_;
+ return out;
+ }
+};
+
+inline std::ostream& operator<<(std::ostream& out,
+ const TensorMetaData& meta_data) {
+ return meta_data.print(out);
+}
+
+class RdmaChannel;
+
+void MRDeleter(ibv_mr* mr);
+using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;
+
+// RdmaMemoryMgr
+// Manages the local meta-data cache, and the registered RDMA memory regions.
+class RdmaMemoryMgr {
+ public:
+ static RdmaMemoryMgr& Singleton() {
+ static RdmaMemoryMgr instance;
+ return instance;
+ }
+
+ // Memory regions
+ ibv_mr* FindMemoryRegion(void* addr, size_t length);
+ void InsertMemoryRegion(void* addr, size_t length,
+ const std::string& allocator_name);
+ void EvictMemoryRegion(void* addr, size_t length);
+
+ // Tensor meta-data cache
+ const TensorMetaData* GetTensorMetaData(const std::string& tensor_name);
+ const TensorMetaData* SetTensorMetaData(const std::string& tensor_name,
+ DataType dtype,
+ const TensorShape& shape,
+ bool is_dead, size_t proto_size);
+
+ struct ibv_pd* pd_;
+
+ protected:
+ RdmaMemoryMgr() : pd_(nullptr) {}
+
+ static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
+ return ptr < reinterpret_cast<char*>(other->addr) + other->length;
+ }
+
+ private:
+ mutex tensor_meta_data_mu_;
+ std::unordered_map<std::string, TensorMetaData> tensors_meta_data_;
+
+ // Managed memory regions
+ mutex mrs_mu_;
+ std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(mrs_mu_);
};
-class RdmaBuffer;
+
+// RdmaTensorRequest
+// Represents a single tensor request.
+class RdmaTensorRequest {
+ public:
+ typedef Rendezvous::DoneCallback RecvDoneCallback;
+
+ // Creates a tensor request identified by index.
+ RdmaTensorRequest(uint32_t index, const string& key, int64 step_id,
+ RdmaChannel* channel, Device* dst_dev,
+ const Rendezvous::Args recv_args,
+ const RecvDoneCallback& done);
+ ~RdmaTensorRequest();
+
+ // Request unique index.
+ uint32_t index() { return index_; }
+
+ // Start the tensor request sequence.
+ //
+ // 1. Allocate the result tensor (and proxy tensor if required).
+ // 2. Send RDMA_MESSAGE_TENSOR_REQUEST to the remote side.
+ void Start();
+
+ // Receive tensor meta-data.
+ //
+ // 1. Update the local meta-data cache.
+ // 2. Reallocate the result tensor (and proxy tensor if required).
+ // 3. Re-send the request to the remote side.
+ void RecvTensorMetaData(DataType dtype, TensorShape shape, bool is_dead,
+ size_t proto_size);
+
+ // Receive tensor content (RDMA write was completed).
+ //
+ // Decode proto if required and/or move to GPU if the content was not
+ // written to it directly (GPU direct is not avaliable). Afterwards,
+ // invoke Done().
+ void RecvTensorContent();
+
+ // Receive error status (in case of a remote error).
+ // Invoke Done() with the status code.
+ void RecvErrorStatus(const Status& status);
+
+#ifdef RDMA_DATA_VALIDATION
+ // Receive tensor checksum
+ //
+ // For validation: Get and store the Tensor's expected checksum for the
+ // current request. Compare the result Tensor's checksum with the stored
+ // checksum right before invoking Done().
+ void RecvTensorChecksum(uint64_t checksum) { checksum_ = checksum; }
+#endif
+
+ private:
+ void Done(const Status& s);
+ void Send(RdmaMessageType message_type);
+ bool AllocateTensors();
+ void AllocateTensorsAsync(StatusCallback done);
+ void DeallocateTensors();
+
+ uint32_t index_;
+ string key_;
+ int64 step_id_;
+ RdmaChannel* channel_;
+ Device* dst_dev_;
+ Rendezvous::Args recv_args_;
+ const TensorMetaData* meta_data_;
+ Tensor* result_tensor_;
+ Tensor* proxy_tensor_;
+ void* rdma_addr_;
+ ibv_mr* mr_;
+ RecvDoneCallback done_;
+#ifdef RDMA_DATA_VALIDATION
+ uint64_t checksum_;
+#endif
+};
+
+// RdmaTensorResponse
+// Represents a single tensor response.
+class RdmaTensorResponse {
+ public:
+ // Creates a response for request message.
+ RdmaTensorResponse(RdmaChannel* channel, const RdmaMessage& rm)
+ : channel_(channel), rm_(rm) {}
+
+ void Update(const RdmaMessage& rm) { rm_ = rm; }
+
+ // Start the tensor response sequence.
+ //
+ // 1. Find the tensor in the local tag-match table and invoke RecvHandler.
+ // (Using RecvLocalAsync()).
+ // 2. Compare the tensor's meta-data to the meta-data in the message (taken
+ // from the requester's local cache).
+ // If meta-data changed:
+ // a. Clone the tensor to be sent later.
+ // b. Send a meta-data update message and wait for re-request.
+ // Else:
+ // a. Send the tensor's content (using direct RDMA write).
+ void Start();
+
+ // Resume the response sequence, after a re-request.
+ //
+ // 1. Send the tensor's content that was cloned earlier.
+ void Resume();
+
+ // Destroy the response's resources and remove it from the pending list.
+ void Destroy();
+
+ private:
+ void RecvHandler(Rendezvous::ParsedKey parsed,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in,
+ bool is_dead);
+ void Clone(const Tensor& in, const TensorProto& proto, bool is_dead);
+ void Send(const Tensor& in, const TensorProto& proto, bool is_dead,
+ const Status& status);
+ bool TensorMetaDataChanged(const Tensor& in, bool is_dead);
+ Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
+ Device** src_dev);
+ void SendMetaData(const Tensor& in, const TensorProto& proto, bool is_dead);
+ void SendContent(const Tensor& in, const TensorProto& proto, bool is_dead);
+ void SendErrorStatus(const Status& status);
+
+ RdmaChannel* channel_;
+ RdmaMessage rm_; // The request message
+ Device* src_dev_ = nullptr;
+ TensorBuffer* src_buffer_ = nullptr;
+ void* src_addr_ = nullptr;
+ ibv_mr* mr_ = nullptr;
+ uint64_t checksum_ = 0;
+ bool meta_data_changed_ = false;
+
+ // Re-item:
+ TensorProto* proto_ = nullptr;
+ Tensor* tensor_ = nullptr;
+ bool is_dead_ = false;
+};
+
+class RdmaMessageBuffer;
// Class that represents the Rdma Adapter.
// Responsible for creation of the completion queue, and handling
// of work completions.
class RdmaAdapter {
friend class RdmaChannel;
- friend class RdmaBuffer;
- friend class RdmaAckBuffer;
friend class RdmaMessageBuffer;
- friend class RdmaTensorBuffer;
+ friend class RdmaTensorResponse;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
@@ -133,10 +406,10 @@ class RdmaAdapter {
// Responsible for connecting queue pairs.
class RdmaChannel {
friend class RdmaAdapter;
- friend class RdmaBuffer;
- friend class RdmaAckBuffer;
friend class RdmaMessageBuffer;
friend class RdmaTensorBuffer;
+ friend class RdmaTensorRequest;
+ friend class RdmaTensorResponse;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
@@ -146,22 +419,28 @@ class RdmaChannel {
~RdmaChannel();
inline const RdmaAddress& self() { return self_; }
RdmaAddress address() const;
- inline const std::vector<RdmaBuffer*>& message_buffers() const {
+ inline const std::vector<RdmaMessageBuffer*>& message_buffers() const {
return message_buffers_;
}
void Connect(const RdmaAddress& remoteAddr);
void Connect();
void Recv();
- RdmaBuffer* FindBuffer(const uint32_t index);
- RdmaBuffer* FindBuffer(const string& name);
- RdmaBuffer* FindOrCreateBuffer(const string& name,
- BufferType buffer_type = TENSOR);
- uint32_t LookupBufferIndex(const string& buffer_name);
void SetRemoteAddress(const RdmaAddress& ra, bool override);
- void InsertRecvCallback(const string& key, std::function<void()> recv_done);
- void RemoveRecvCallback(const string& key);
- void RunRecvCallback(const string& key);
- static const int kNumMessageBuffers = 4;
+
+ // Requests:
+ RdmaTensorRequest* InsertTensorRequest(
+ const string& key, int64 step_id, Device* dst_dev,
+ const Rendezvous::Args recv_args,
+ const RdmaTensorRequest::RecvDoneCallback& done);
+ void RemoveTensorRequest(uint32_t request_index);
+ RdmaTensorRequest* GetTensorRequest(uint32_t request_index);
+
+ // Responses:
+ RdmaTensorResponse* AddTensorResponse(const RdmaMessage& rm);
+ RdmaTensorResponse* UpdateTensorResponse(const RdmaMessage& rm);
+ void RemoveTensorResponse(uint32_t request_index);
+
+ static const int kNumMessageBuffers = 2;
static const int kPingRecvWrid = 0;
private:
@@ -179,36 +458,31 @@ class RdmaChannel {
string remote_name_;
ibv_qp* qp_;
mutex mu_;
- bool connected_ GUARDED_BY(bt_mu_) = false;
- RdmaAddress remote_ GUARDED_BY(bt_mu_);
- bool remote_set_ GUARDED_BY(bt_mu_) = false;
+ bool connected_ GUARDED_BY(mu_) = false;
+ RdmaAddress remote_ GUARDED_BY(mu_);
+ bool remote_set_ GUARDED_BY(mu_) = false;
mutex ct_mu_;
- typedef std::unordered_map<string, std::function<void()> > CallbackTable;
- CallbackTable callback_table_ GUARDED_BY(ct_mu_);
- mutex bt_mu_;
- typedef std::unordered_map<unsigned int, RdmaBuffer*> BufferTable;
- BufferTable buffer_table_ GUARDED_BY(bt_mu_);
- typedef std::unordered_map<uint32_t, string> BufferIndexNameTable;
- BufferIndexNameTable buffer_index_name_table_ GUARDED_BY(bt_mu_);
- typedef std::unordered_map<string, uint32_t> BufferNameIndexTable;
- BufferNameIndexTable buffer_name_index_table_ GUARDED_BY(bt_mu_);
- RdmaBuffer* tx_message_buffer_;
- RdmaBuffer* rx_message_buffer_;
- RdmaBuffer* tx_ack_buffer_;
- RdmaBuffer* rx_ack_buffer_;
- std::vector<RdmaBuffer*> message_buffers_;
+ typedef std::unordered_map<uint32_t, RdmaTensorRequest> RequestTable;
+ RequestTable request_table_ GUARDED_BY(ct_mu_);
+ uint32_t request_serial_ GUARDED_BY(ct_mu_);
+ mutex responses_mu_;
+ typedef std::unordered_map<uint32_t, RdmaTensorResponse> ResponsesTable;
+ ResponsesTable responses_table_ GUARDED_BY(responses_mu_);
+ RdmaMessageBuffer* tx_message_buffer_;
+ RdmaMessageBuffer* rx_message_buffer_;
+ std::vector<RdmaMessageBuffer*> message_buffers_;
};
-// Class that represents a buffer for Rdma writes and reads.
-class RdmaBuffer {
+// Class that represents a buffer for Rdma message sending.
+class RdmaMessageBuffer {
friend class RdmaChannel;
friend class RdmaAdapter;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
public:
- explicit RdmaBuffer(RdmaChannel* channel, string name);
- virtual ~RdmaBuffer();
+ explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
+ ~RdmaMessageBuffer();
inline void* buffer() const { return buffer_; }
inline ibv_mr* self() const { return self_; }
@@ -223,13 +497,15 @@ class RdmaBuffer {
}
void FreeBuffer();
void EnqueueItem(string Item);
- virtual void SendNextItem() {};
+ void SendNextItem();
void CreateCPUBuffer(size_t size, bool lock = true);
void SetRemoteMR(RemoteMR rmi, bool override);
- uint32_t LookupBufferIndex(const string& buffer_name) {
- return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);
- }
void Write(uint32_t imm_data, size_t buffer_size);
+ static void Write(const RdmaChannel* channel, uint32_t imm_data,
+ size_t buffer_size, uint64_t src_addr, uint32_t lkey,
+ uint64_t remote_addr, uint32_t rkey,
+ RdmaWriteIDType write_type, void* write_context);
+ static void SendAck(const RdmaChannel* channel);
protected:
const RdmaChannel* channel_;
@@ -245,126 +521,7 @@ class RdmaBuffer {
BufferStatus remote_status_ GUARDED_BY(mu_) = none;
};
-class RdmaAckBuffer : public RdmaBuffer {
- public:
- explicit RdmaAckBuffer(RdmaChannel* channel, string name);
- virtual ~RdmaAckBuffer() override {}
- void SendNextItem() override;
-};
-
-class RdmaMessageBuffer : public RdmaBuffer {
- friend class RdmaChannel;
- friend class RdmaAapater;
-
- public:
- explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
- virtual ~RdmaMessageBuffer() override {}
- void SendNextItem() override;
-};
-
-class RdmaTensorBuffer : public RdmaBuffer {
- public:
- explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
- virtual ~RdmaTensorBuffer() override;
- void SendNextItem() override;
- void PostCopyOperations(bool can_memcpy, size_t buffer_size,
- size_t tensor_bytes, const string& key,
- const Tensor& in, int64 step_id, bool is_dead,
- const string& key_with_step_id, const Tensor* copy,
- const TensorProto* proto, const StringPiece* copy_buf,
- const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args);
-
- void ReSendNextItem();
-
- private:
- Rendezvous::DoneCallback getRecvTensorCallback(
- const string& key_with_step_id, const string& key, int64 step_id,
- const Rendezvous::ParsedKey& parsed);
-
- struct ReItem {
- Rendezvous::Args send_args;
- Rendezvous::Args recv_args;
- Tensor in;
- bool is_dead;
-
- ReItem(const Rendezvous::Args& send_args_,
- const Rendezvous::Args& recv_args_, const Tensor& in_, bool is_dead_)
- : send_args(send_args_),
- recv_args(recv_args_),
- in(in_),
- is_dead(is_dead_) {
- if (send_args.device_context) {
- send_args.device_context->Ref();
- }
- if (recv_args.device_context) {
- recv_args.device_context->Ref();
- }
- }
-
- ~ReItem() {
- if (send_args.device_context) {
- send_args.device_context->Unref();
- }
- if (recv_args.device_context) {
- recv_args.device_context->Unref();
- }
- }
- };
- typedef std::map<string, ReItem*> Table;
- typedef Table::iterator Itable;
-
- std::queue<string> requeue GUARDED_BY(mu_);
- Table retable GUARDED_BY(mu_);
-};
-
-struct RdmaMessage {
- RdmaMessageType type_;
- uint16_t name_size_;
- string name_;
- int64 step_id_;
- uint64_t buffer_size_;
- uint64_t remote_addr_;
- uint32_t rkey_;
- bool is_dead_;
- DataType data_type_;
- TensorShape tensor_shape_;
- size_t tensor_bytes_;
-
- // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
- // 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
- // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
- // ...| XB | XB | 8B |...
- //
- static const size_t kNameCapacity = 512;
- static const size_t kTypeStartIndex = 0;
- static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
- static const size_t kNameStartIndex =
- kNameSizeStartIndex + sizeof(name_size_);
- static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
- static const size_t kBufferSizeStartIndex =
- kStepIdStartIndex + sizeof(step_id_);
- static const size_t kRemoteAddrStartIndex =
- kBufferSizeStartIndex + sizeof(buffer_size_);
- static const size_t kRkeyStartIndex =
- kRemoteAddrStartIndex + sizeof(remote_addr_);
- static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
- static const size_t kDataTypeStartIndex =
- kIsDeadStartIndex + sizeof(is_dead_);
- static const size_t kTensorShapeStartIndex =
- kDataTypeStartIndex + sizeof(data_type_);
- static const size_t kTensorBytesStartIndex =
- kTensorShapeStartIndex + sizeof(TensorShape);
- static const size_t kTensorBufferStartIndex =
- kTensorBytesStartIndex + sizeof(tensor_bytes_);
- static const size_t kMessageTotalBytes = kTensorBufferStartIndex;
- static const size_t kRdmaMessageBufferSize = kMessageTotalBytes;
- static const size_t kRdmaAckBufferSize = kMessageTotalBytes;
- static string CreateMessage(const RdmaMessage& rm);
- static void ParseMessage(RdmaMessage& rm, void* buffer);
-};
-
} // namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_H_
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
index 9cb307bcfa..369bd986df 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -16,11 +16,16 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include <fstream>
#include <vector>
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/common_runtime/bfc_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -53,7 +58,7 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
void RdmaMgr::SetupChannels() {
for (const auto& p : channel_table_) {
string worker_name = p.first;
- LOG(INFO) << "connecting to remote node " << worker_name;
+ RDMA_LOG(2) << "Connecting to remote node " << worker_name;
RdmaChannel* rc = p.second;
GetRemoteAddressRequest req;
GetRemoteAddressResponse resp;
@@ -78,39 +83,49 @@ void RdmaMgr::SetupChannels() {
mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
}
// synchronous call
- Status s = client->GetRemoteAddress(&req, &resp);
- // save obtained remote addresses
- // connect to the remote channel
- if (s.ok()) {
- CHECK(worker_name.compare(resp.host_name()) == 0);
- RdmaAddress ra;
- ra.lid = resp.channel().lid();
- ra.qpn = resp.channel().qpn();
- ra.psn = resp.channel().psn();
- ra.snp = resp.channel().snp();
- ra.iid = resp.channel().iid();
- rc->SetRemoteAddress(ra, false);
- rc->Connect();
- int i = 0;
- int idx[] = {1, 0, 3, 2};
- for (const auto& mr : resp.mr()) {
- // the connections are crossed, i.e.
- // local tx_message_buffer <---> remote rx_message_buffer_
- // local rx_message_buffer <---> remote tx_message_buffer_
- // local tx_ack_buffer <---> remote rx_ack_buffer_
- // local rx_ack_buffer <---> remote tx_ack_buffer_
- // hence idx[] = {1, 0, 3, 2}.
- RdmaBuffer* rb = rc->message_buffers_[idx[i]];
- RemoteMR rmr;
- rmr.remote_addr = mr.remote_addr();
- rmr.rkey = mr.rkey();
- rb->SetRemoteMR(rmr, false);
- i++;
+ Status s;
+ int attempts = 0;
+ static const int max_num_attempts = 5;
+ do {
+ s = client->GetRemoteAddress(&req, &resp);
+ // save obtained remote addresses
+ // connect to the remote channel
+ if (s.ok()) {
+ CHECK(worker_name.compare(resp.host_name()) == 0);
+ RdmaAddress ra;
+ ra.lid = resp.channel().lid();
+ ra.qpn = resp.channel().qpn();
+ ra.psn = resp.channel().psn();
+ ra.snp = resp.channel().snp();
+ ra.iid = resp.channel().iid();
+ rc->SetRemoteAddress(ra, false);
+ rc->Connect();
+ int i = 0;
+ int idx[] = {1, 0};
+ for (const auto& mr : resp.mr()) {
+ // the connections are crossed, i.e.
+ // local tx_message_buffer <---> remote rx_message_buffer_
+ // local rx_message_buffer <---> remote tx_message_buffer_
+ // hence idx[] = {1, 0}.
+ RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]];
+ RemoteMR rmr;
+ rmr.remote_addr = mr.remote_addr();
+ rmr.rkey = mr.rkey();
+ rb->SetRemoteMR(rmr, false);
+ i++;
+ }
+ CHECK(i == RdmaChannel::kNumMessageBuffers);
+ } else {
+ LOG(ERROR) << "Connecting to " << worker_name << ": Got "
+ << s.error_message() << ". Retrying (" << (attempts + 1)
+ << "/" << max_num_attempts << ")...";
+ if (++attempts == max_num_attempts) {
+ break;
+ }
+ worker_env_->env->SleepForMicroseconds(2000000);
}
- CHECK(i == RdmaChannel::kNumMessageBuffers);
- } else {
- LOG(ERROR) << s.error_message();
- }
+ } while (!s.ok());
+ RDMA_LOG(0) << "Connected to remote node " << worker_name;
delete client;
}
}
@@ -144,19 +159,17 @@ bool RdmaMgr::ConnectivityCheck() {
ibv_wc_status s = rdma_adapter_->wc_[i].status;
// recv complete
if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) {
- CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
- rdma_adapter_->wc_[i].status)
- << "(" << rdma_adapter_->wc_[i].status
- << ") for PING_RECV_WRID";
+ CHECK(s == IBV_WC_SUCCESS)
+ << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
+ << rdma_adapter_->wc_[i].status << ") for PING_RECV_WRID";
++rcnt;
// send complete
} else {
RdmaChannel* rc =
reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
- CHECK(s == IBV_WC_SUCCESS) << ": " << ibv_wc_status_str(
- rdma_adapter_->wc_[i].status)
- << "(" << rdma_adapter_->wc_[i].status
- << ") to " << rc->remote_name_;
+ CHECK(s == IBV_WC_SUCCESS)
+ << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
+ << rdma_adapter_->wc_[i].status << ") to " << rc->remote_name_;
++scnt;
}
} // for
@@ -183,6 +196,139 @@ RdmaChannel* RdmaMgr::FindChannel(const string& name) {
return iter->second;
}
+bool IsGDRAvailable() {
+#if defined(__APPLE__)
+ return false;
+#elif defined(PLATFORM_WINDOWS)
+ return false;
+#else
+ std::ifstream ifs("/proc/modules");
+ string line;
+ while (std::getline(ifs, line)) {
+ auto sep = line.find(' ');
+ CHECK_NE(sep, std::string::npos);
+ if (line.substr(0, sep) == "nv_peer_mem") {
+ return true;
+ }
+ }
+ return false;
+#endif
+}
+
+int TryToReadNumaNode(ibv_device* device) {
+#if defined(__APPLE__)
+ LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
+ return 0;
+#elif defined(PLATFORM_WINDOWS)
+ // Windows support for NUMA is not currently implemented. Return node 0.
+ return 0;
+#else
+ VLOG(2) << "Trying to read NUMA node for device: " << device->name;
+ static const int kUnknownNumaNode = -1;
+
+ auto filename = string(device->ibdev_path) + "/device/numa_node";
+
+ std::ifstream ifs(filename.c_str());
+ string content;
+ CHECK(std::getline(ifs, content));
+
+ int32 value;
+ if (strings::safe_strto32(content, &value)) {
+ if (value < 0) {
+ LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
+ << value
+ << "), but there must be at least one NUMA node"
+ ", so returning NUMA node zero";
+ return 0;
+ }
+ LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
+ return value;
+ }
+ return kUnknownNumaNode;
+#endif
+}
+
+void MRDeleter(ibv_mr* mr) {
+ if (mr) {
+ ibv_dereg_mr(mr);
+ }
+}
+
+// TODO(byronyi): remove this class duplicated from the one in
+// common/runtime/gpu/pool_allocator.h when it is available in common_runtime
+class BasicCPUAllocator : public SubAllocator {
+ public:
+ ~BasicCPUAllocator() override {}
+
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ return port::AlignedMalloc(num_bytes, alignment);
+ }
+ void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
+};
+
+// TODO(byronyi): remove this class and its registration when the default
+// cpu_allocator() returns visitable allocator
+class BFCRdmaAllocator : public BFCAllocator {
+ public:
+ BFCRdmaAllocator()
+ : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
+ }
+};
+
+REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);
+
+void RdmaMgr::InitAllocators() {
+ RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_;
+
+ Allocator* allocators[] = {
+#if GOOGLE_CUDA
+ ProcessState::singleton()->GetCUDAHostAllocator(0),
+ ProcessState::singleton()->GetCPUAllocator(0),
+#endif // GOOGLE_CUDA
+ cpu_allocator(),
+ };
+
+ using namespace std::placeholders;
+
+ std::set<Allocator*> instrumented_;
+
+ // Host memory allocators
+ for (Allocator* allocator : allocators) {
+ VisitableAllocator::Visitor alloc_visitor =
+ std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
+ &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name());
+ VisitableAllocator::Visitor free_visitor = std::bind(
+ &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);
+
+ auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
+ CHECK(visitable_allocator)
+ << "is not visitable for instrumentation" << allocator->Name();
+ // Make sure we don't instrument the same allocator twice
+ if (instrumented_.find(allocator) == std::end(instrumented_)) {
+ visitable_allocator->AddAllocVisitor(alloc_visitor);
+ visitable_allocator->AddFreeVisitor(free_visitor);
+ instrumented_.insert(allocator);
+ LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
+ }
+ }
+
+#if GOOGLE_CUDA
+ if (IsGDRAvailable()) {
+ // Note we don't free allocated GPU memory so there is no free visitor
+ int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
+
+ char buf[8];
+ sprintf(buf, "gpu");
+ VisitableAllocator::Visitor cuda_alloc_visitor =
+ std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
+ &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf));
+
+ ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
+ LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
+ }
+#endif // GOOGLE_CUDA
+}
+
} // end namespace tensorflow
#endif
diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h
index e711e60478..9fffc335bb 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.h
+++ b/tensorflow/contrib/verbs/rdma_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+#define TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
#ifdef TENSORFLOW_USE_VERBS
@@ -38,6 +38,7 @@ class RdmaMgr {
RdmaChannel* FindChannel(const string& key);
void SetupChannels();
bool ConnectivityCheck();
+ void InitAllocators();
const string& local_worker() { return local_worker_; }
private:
@@ -54,4 +55,4 @@ class RdmaMgr {
} // namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index 74f6681af3..ad3dce1784 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,10 +21,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
-#if GOOGLE_CUDA
-#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
-#endif // GOOGLE_CUDA
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -36,11 +32,6 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr)
: BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {}
- void RecvPostCopyOps(const string& key, const string& key_with_step_id,
- const Rendezvous::Args& recv_args,
- const DoneCallback& done, const RdmaMessage& rm,
- RdmaChannel* rc, Tensor& val, const Status& s);
-
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
@@ -74,101 +65,18 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
string key(std::move(parsed.FullKey().ToString()));
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
- // insert callback
- rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
- recv_args, parsed, done]() {
- Status src_s, dst_s, s;
- Device* src_dev, *dst_dev;
- src_s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
- dst_s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
- if (!src_s.ok() || !dst_s.ok()) {
- s = src_s.ok() ? dst_s : src_s;
- LOG(ERROR) << "s is not ok, error code " << s.error_message();
- done(s, Args(), recv_args, Tensor(), true);
- return;
- }
- RdmaBuffer* rb = rc->FindBuffer(key);
- RdmaMessage rm;
- CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
- RdmaMessage::ParseMessage(rm, rb->buffer_);
- CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
- Tensor val;
- if (!rm.is_dead_) {
- void* input = static_cast<char*>(rb->buffer_) +
- RdmaMessage::kTensorBufferStartIndex;
- bool can_memcpy = DataTypeCanUseMemcpy(rm.data_type_);
- if (can_memcpy) {
- if (dst_dev->tensorflow_gpu_device_info() &&
- (!recv_args.alloc_attrs.on_host())) {
-#if GOOGLE_CUDA
- CHECK(recv_args.device_context)
- << "send dev name: " << src_dev->name()
- << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
- Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
- Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
- memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
-
- Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs);
- Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_);
-
- GPUUtil::CopyCPUTensorToGPU(
- &copy, recv_args.device_context, dst_dev, &gpu_copy,
- [this, gpu_copy, key, key_with_step_id, recv_args, done, rm, rc](
- const Status& s) {
- CHECK(s.ok()) << "copy tensor to gpu sync";
- Tensor val;
- val = std::move(gpu_copy);
- RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc,
- val, s);
- });
-#endif // GOOGLE_CUDA
- return;
- } else {
- AllocatorAttributes host_alloc_attrs;
- host_alloc_attrs.set_gpu_compatible(true);
- host_alloc_attrs.set_on_host(true);
- Allocator* alloc = dst_dev->GetAllocator(host_alloc_attrs);
- Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
- memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
- val = std::move(copy);
- }
- } else {
- TensorProto proto;
- CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
- rb->size_);
- CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
- << "fail to parse proto from array";
- s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
- }
- }
- RecvPostCopyOps(key, key_with_step_id, recv_args, done, rm, rc, val, s);
- });
- // append key to message queue
- RdmaBuffer* rb = rc->tx_message_buffer_;
- RdmaMessage rm;
- rm.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
- rm.name_size_ = key.size();
- rm.name_ = key;
- rm.step_id_ = step_id_;
- string message = RdmaMessage::CreateMessage(rm);
- rb->EnqueueItem(message);
- rb->SendNextItem();
-}
-void RdmaRemoteRendezvous::RecvPostCopyOps(
- const string& key, const string& key_with_step_id,
- const Rendezvous::Args& recv_args, const DoneCallback& done,
- const RdmaMessage& rm, RdmaChannel* rc, Tensor& val, const Status& s) {
- rc->RemoveRecvCallback(key_with_step_id);
- RdmaMessage br;
- br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
- br.name_size_ = key.size();
- br.name_ = key;
- string message = RdmaMessage::CreateMessage(br);
- RdmaBuffer* tb = rc->tx_message_buffer_;
- tb->EnqueueItem(message);
- tb->SendNextItem();
- done(s, Args(), recv_args, val, rm.is_dead_);
+ Device* dst_dev;
+ s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
+ CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor(), true);
+ return;
+ }
+
+ RdmaTensorRequest* request =
+ rc->InsertTensorRequest(key, step_id_, dst_dev, recv_args, done);
+ request->Start();
}
RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env)
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
index 2dedd6c48f..c0d6f59c48 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+#ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
#ifdef TENSORFLOW_USE_VERBS
@@ -60,4 +60,4 @@ class RdmaRendezvousMgr : public BaseRendezvousMgr {
} // end namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+#endif // TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
index a606ef75a4..47ed83f521 100644
--- a/tensorflow/contrib/verbs/verbs_server_lib.cc
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -104,6 +104,7 @@ Status VerbsServer::Start() {
[this] { verbs_service_->HandleRPCsLoop(); }));
rdma_mgr_->SetupChannels();
CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!";
+ rdma_mgr_->InitAllocators();
verbs_state_ = CONNECTED;
}
}
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.h b/tensorflow/contrib/verbs/verbs_server_lib.h
index 855380129f..54ce8c1d47 100644
--- a/tensorflow/contrib/verbs/verbs_server_lib.h
+++ b/tensorflow/contrib/verbs/verbs_server_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+#ifndef TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+#define TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
#ifdef TENSORFLOW_USE_VERBS
@@ -63,4 +63,4 @@ class VerbsServer : public GrpcServer {
} // namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+#endif // TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
diff --git a/tensorflow/contrib/verbs/verbs_service.proto b/tensorflow/contrib/verbs/verbs_service.proto
index 0df1fed4b9..abdae1d84f 100644
--- a/tensorflow/contrib/verbs/verbs_service.proto
+++ b/tensorflow/contrib/verbs/verbs_service.proto
@@ -50,6 +50,12 @@ message GetRemoteAddressResponse {
repeated MemoryRegion mr = 3;
}
+message ErrorStatusProto {
+ int32 error_code = 1;
+ string error_message = 2;
+ string error_details = 3;
+}
+
////////////////////////////////////////////////////////////////////////////////
//
// VerbsService
diff --git a/tensorflow/contrib/verbs/verbs_with_0_copies.png b/tensorflow/contrib/verbs/verbs_with_0_copies.png
new file mode 100644
index 0000000000..0641e2fd50
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_with_0_copies.png
Binary files differ
diff --git a/tensorflow/contrib/verbs/verbs_with_0_copies.xml b/tensorflow/contrib/verbs/verbs_with_0_copies.xml
new file mode 100644
index 0000000000..16130a961b
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_with_0_copies.xml
@@ -0,0 +1 @@
+<mxfile userAgent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.84 Safari/537.36" version="7.8.7" editor="www.draw.io" type="device"><diagram name="Page-1" id="74e2e168-ea6b-b213-b513-2b3c1d86103e">7Vxtc9o4EP41zKQfmsGW3/hIgPQ60/RyIZ1rPzHClsFXY1FZEOivP8mW8ZsAB2yXtHQ6jb2SJXl3n0e7K6cdMFhsPhC4nD9gB/kdtetsOmDYUVVFUw32g0u2scRUu7FgRjxHdEoFY+8nEsKk28pzUJjrSDH2qbfMC20cBMimORkkBL/ku7nYz8+6hDNUEoxt6Jel/3oOnQup0u2mDX8hbzYXU1u6aJhC+/uM4FUg5uuowI3+xM0LmIwl+odz6OCXjAiMOmBAMKbx1WIzQD7XbaK2+Ln7Pa27dRMU0CoP6CB+Yg39FUqWHC2MbhNlRK+D+APdDrh7mXsUjZfQ5q0vzPxMNqcLn90p7NL1fH+AfUzYfYAD1ulOzIAIRZu9y1R2L8+cCuEFomTLumx2mo8fEf5kiduX1DhWIptn7GIkQigcYrYbOlUKuxB6kevIkqjI8NkMd463zqnK+LHihrtjL0rfQ9+bBR3QZz185NK0lV3NxM9olHAJg0Q2ppDQm3dJE1tatjUjjqbOS+tfTSKbEskK2lqYcsua+r6PbUgRJwIUhJiEt03OqfI5xyhwbp5Hn8d/P02eRv98GY2f37VhAam2c7PVBVDGTg5Elmtzu1OCv6NMi2FbaOru5iuBVQLp/fgFefwqRhnAalcC4B3jngPghGxrR/ATstfPkTs+IAqHkMKbC/GQ2oD3ZenEsGNah+/ZNeTbLrTnqHkAPiHYLuxlHAjKVCBjg8q0+XsBGahtAlkxjkcryGGRnLjFhM7xDAfQH6XSu7yWMxr9D1G6FcEoXFHMROkInzBein579RjiFbGTEFIsjW3oM5R002IZX+NBbRPkQ+qt89HoWZpTG6LAiCQGBMUgJShc4iBshRvs9SfGDX4/3AZ2s7QbUcAAL7dRGsKvH79wyCPisWd/8hf/6EZv/2PlEeTsffs3B3dT+aX7tv4r0M20RbZfszff+GC3Or/dePRr7u6bmKgaJ2hlTkhS5fo4QTz6iD22lJ0pe728KU2jYKF4UeKp1Eh9QuA2023JO4QH5jELO4RZyECP9E9SttRH4hWkHrPTSTXm00rMd++RkD57Cw7cjjngf1UDLjjggmm4LPJGjN4kwhvMYTBDDmMccF8V53O8mK7C4xh3GHvY1MOMjYbMbzjCLgP3LW/zvePbfDiHS37p+mjT5xWfCKyOuBzaPgxDzz5Ioa5lI9vOIr5bRnxJzVNL1/TKiLckQYBaEfAZXesSVSeyM3kBFCI6tcgL8euUeKE0kNbt3jK/MUxLUSxD10wjP65SjW9OgHjiHm35y5k+Id0FuhflFIbSu1WtHlAVy1KApqs5U2pFU1Z1kcPDgoo70ikeUi5zfkNhyUl4OJh3gbypRUFTUuMUMeTQZoZHTH7Hudbj4aloWHiOE8Unsi0gH7M0wd+gzN+axH3UOqK28ob7Gf++qraK8U6bqpblw00VQqJM75GFyfI0r61yMM/+5MFadgVPwwc+xAvxorz0Jq7SdaITI8qM/e61S3/zqZsh8cvmQjigH9+S28zlRMK2i+2q7tWqKdmrW8rYrKIFtWr74/6M7Zwd1CwZdFcaztDBm4eJ1mqFA1Q4fn0jmU6yn+WQYlZESjtRrVpIdTTzxDi2OJBe9IWaaimleZR6ayOgQj29EfeTlNbOdT+j7H7gstzvcPZjgkaSqtIXEPUlVaC8JdR9rDqIo7Xf6VRVUlqMI2uCN9soQOXnDPcOyh4veJWOF0oDyyLj6PTkY7BmYGMXDs+p+IGu7/NPl45Fxa9S0ZEz1aHkdJf7aeBE77rAayReGoX0tEzjzQfxxePWdoP4JN6UAHyuVIKAyNFlCEeMSEnGUHzEPXY6vVbAXMX2ghkT6Ondc5QevFf3GZ05HnH9aHebe46DgibKBoqp5yzbKxt299Fb1rDFHOAku6pN2ZV/J/EnW9U0jlq115RRZamEYO0iL7o4CiDsHWmldgSu2+1y8ihBdvjQnzyMxuP+h9Ek/1Vcxt7xyCUWnjbgBRdWBwRWnqoVyTeq0uiyjkKgVq65Nmf8h9FzfzLss3+eRuPHvz+PR1cHkDgA0Np0AFm9rXH0XwnggP21Vglg/0lALfYv1s+vFpdY3NDbtHhT2XfNSXUi+LhYxP2ruCF3Qv47M8VRBQO9yvsOJYiyVLLm9773kO+E+VfPLpBulyzPHaSp7sRjXrqJRQFciMaQouWE2V70XGCKJtBxiBB8R9v4in+mPYk/0z6SGZ9w6nUMuTwvNqaGbnTKNUDXVaMa4KVh2MhjWJV86QRkGbZeB4ab/tWihjAsg1m31PvPBbpUPweBfoXtebAFkmCrOdjKvk+8wvYPhO3r9ucrsk9Att7mhpyMcUV2ceqC818+viueVF1xtwd3hqR2XRfu2G36XxzEJ8/p/yMBRv8D</diagram></mxfile> \ No newline at end of file
diff --git a/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.jpg b/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.jpg
new file mode 100644
index 0000000000..8bc69b889d
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.jpg
Binary files differ
diff --git a/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.xml b/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.xml
new file mode 100644
index 0000000000..484e7c78ae
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_with_0_copies_phase1_protocol.xml
@@ -0,0 +1 @@
+<mxfile userAgent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/62.0.3202.94 Safari/537.36" version="7.8.4" editor="www.draw.io" type="device"><diagram name="Page-1" id="74e2e168-ea6b-b213-b513-2b3c1d86103e">7Vxbc5s4FP41nuk+pMMd8ujYTrYzTZqNk9nmyaOAbNhgRIWc2P31K4G4yzZ1AHtaZzJjOBLS8TnnOzeRDNTRcn2DQejeIgf6A0Vy1gN1PFAUWVMM+sEom4RiGpcJYYE9h0/KCVPvJ+REiVNXngOj0kSCkE+8sEy0URBAm5RoAGP0Xp42R3551xAsYI0wtYFfp/7rOcTlVFmS8oG/obdw+daWzgdegP26wGgV8P0GijqPf5LhJUjX4vMjFzjovUBSJwN1hBEiydVyPYI+k20qtuS56y2jGd8YBqTJA1bywBvwVzDl2PDpo1eO98b4IxsuE+PHijF1ReCaXADfWwQDdUhn+HBO8lF6teCf8SpRCIKUNiUAk09/pUOUqeJogRxvXaa2z01Ke8ECDnpkDCxDehG8ROzjgs4c+j6yAYGPMIgQjkoC64WBKQycT4+Tu+m3h9nD5J+nyfSxax62a6K0m1LaRYlxBpklS3T43fUInIbAZqPv1C9RmkuWPr2Ts6ffIKZMbQWLnEGQujaIlpDgDZ3CH7A4aLlTkw1+/567CCX1EG7BO2RuA3C3tMiWzqFJLzg6xUhNPUbrUH2A9ltia7eQgDEgoHOTa6juNnZjBj2GoE9M1UHc/X4xZq+erq8nDLPT+29308nWTU8LRqrSJ4xkQwCjikCgQ5MBfoswcdECBcCf5NSrssgK4vkPErLh+QxYEURJ+QpfEQpLYmQb7RYi5QutsJ3O4rzSQLqA6TRNLGwMfUC8t/L6H5Kc0pEDivHiON/wU+hQyDzAKERBBLsHDfN8XylM/WG0Cezu9/syj/XyY+VhZjvDPgP7gKGsJRL7LiMUbm7unx7R6F6UXT2dUp7XqSCmEHuUsZ/wqN/45HMnQztq8qQfw8lT0eDN9+LNM1vss85u1x75Xrp75hsdGBq0emhIqvAPhAb+6D3y6M6ZKi+VsipNo6KhhAf+VK6kIcZgU5gWsglR831Us1LL7plvWLvnW9rO+fQi4Ti3sEyGzQKmVguY1x6OyKO3hMyZmCP2W/MyBbeQYDdNy0cuCBbQoX5GvW6KchctX1bRfoQ7NCTZxEPU2YypWTFEdoH6nnO9y/25XuSCkF3Ofbgess5RDFWHX45tH0SRZ5eFNfd8f4R8hOMl0gZPAe9SHe+HodoS5HuKWOIFieoCgaa0D2K/mrwrVewn3NewX1tI1aXPpiwZpiXLlqFrplFeV27mUw6AZWoEfVlFi/5cGhxR9bpx+VmxLlVFtixZ1XSlpDCtqrCmhrB7WbVhbDnEDtSaHTzDqGYKLA8rKzoiGL3CVNUBCmBF+5zEk7exTfUMKf2KuVKPlRt8YOk5TpxpiJxzOfvowherdV+sCcxHaSP/qofCO/T7itqSjihqUYOjq+KKFUD3KBSW7H2VQBfbUqgiAw/jW7bCO6bap5+fkr7cID5BIlSrv8r4iVVThsDAusurVH1/BO2zvOI1VJZwHRx0FVMQdLsposyqBrVmgW57EfWRUGjWFLqlIXdidq/12kVQ6xlDTSKnXU+kAaZk4KZY5P1klXKloNDMA/PI6kJ6VeMtdSVq+8jtdg3UBgcUnRiZoEl1oJEZdSNTj2pku2sMU+2kdEnbSR2ULmrdX7d9FjxK8qLf6ShY0Fo78FSmvyOWETtiubl/aqSHftgaw0h45LGbq/hJWqzte+TEEm3rmHl2muwIYO7KjYDA62ERziF1p7ggdbbiFqEfXpfTUsr2ggUl6PndY5zBXyjbNIgoY3M/jmQuLdth0E2JXlLsZUO9VrP0g9QqakC2olb2FsifrNRqedCrVkXFAQ9vVSc3R3EaYWfpWK5ImphJEuOxBtnx7XB2O5lOhzeTWfntvILCk5VrLvWlAzM4sZ5b1mNLD5gtgfJFOWYbTTet3t/sTvnZa15n5W9Tvqr1qXxRO6xz5Sfv+J21L9C+1iv0t/fbW9F+loLHK1T71tloVe1na8hydr1Pa+iqMm+54E4JX5bLZH4TE2UGyv6Spboq902/ZH27akDRAUzL3/vag74Tlb96kUGyCeFAGfHOAIzIzKNWuk5IAVjywYjAcEZ1z2cuEYEz4DiYE17hJrmidgRmDiBgb/F7wOHTPuRSzTnGi6EbA1EXULHtE8SwXMawInhvSBVl8nobGO76j6I6wrAIZlJt9p8LdKF8dgL9DNuPwVYVJGLdwVb0tt8Ztn8gbD8Un7e+kXvG/i9hX+8zZKdrnLFf3boCj9P3AA2PAs+424I7Q9D0bgt39Db/1wTJuXX+/x/Uyf8=</diagram></mxfile> \ No newline at end of file
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index f2f66fc567..29c515121e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -433,6 +433,7 @@ tf_cuda_library(
"framework/cancellation.h",
"framework/common_shape_fns.h",
"framework/control_flow.h", # TODO(josh11b): Make internal?
+ "framework/dataset.h",
"framework/device_base.h",
"framework/function.h",
"framework/graph_def_util.h",
@@ -594,6 +595,7 @@ cc_library(
tf_gen_op_libs(
is_external = False,
op_lib_names = [
+ "batch_ops",
"bitwise_ops",
"candidate_sampling_ops",
"checkpoint_ops",
@@ -675,6 +677,7 @@ cc_library(
deps = [
":array_ops_op_lib",
":audio_ops_op_lib",
+ ":batch_ops_op_lib",
":bitwise_ops_op_lib",
":candidate_sampling_ops_op_lib",
":checkpoint_ops_op_lib",
@@ -810,6 +813,7 @@ cc_library(
deps = [
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:audio",
+ "//tensorflow/core/kernels:batch_kernels",
"//tensorflow/core/kernels:bincount_op",
"//tensorflow/core/kernels:candidate_sampler_ops",
"//tensorflow/core/kernels:checkpoint_ops",
@@ -1070,8 +1074,8 @@ cc_library(
":protos_all_cc_impl",
"//third_party/eigen3",
"//third_party/fft2d:fft2d_headers",
- "@fft2d//:fft2d",
- "@gemmlowp//:gemmlowp",
+ "@fft2d",
+ "@gemmlowp",
"@protobuf_archive//:protobuf",
],
alwayslink = 1,
@@ -1321,6 +1325,13 @@ tf_pyclif_proto_library(
)
tf_pyclif_proto_library(
+ name = "framework/function_pyclif",
+ proto_lib = ":protos_all_cc",
+ proto_srcfile = "framework/function.proto",
+ visibility = ["//visibility:public"],
+)
+
+tf_pyclif_proto_library(
name = "framework/graph_pyclif",
proto_lib = ":protos_all_cc",
proto_srcfile = "framework/graph.proto",
@@ -1885,6 +1896,13 @@ cc_library(
],
)
+tf_cuda_library(
+ name = "cuda_device_functions",
+ hdrs = ["util/cuda_device_functions.h"],
+ visibility = ["//visibility:public"],
+ deps = [":framework_lite"],
+)
+
# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
cc_library(
name = "protos_cc",
@@ -2333,7 +2351,7 @@ cc_library(
":lib_internal",
":proto_text",
"//third_party/eigen3",
- "@local_config_sycl//sycl:sycl",
+ "@local_config_sycl//sycl",
],
alwayslink = 0,
)
diff --git a/tensorflow/core/api_def/base_api/api_def_Batch.pbtxt b/tensorflow/core/api_def/base_api/api_def_Batch.pbtxt
new file mode 100644
index 0000000000..aea11b64fd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Batch.pbtxt
@@ -0,0 +1,42 @@
+op {
+ graph_op_name: "Batch"
+ summary: "Batches all input tensors nondeterministically."
+ description: <<END
+When many instances of this Op are being run concurrently with the same
+container/shared_name in the same device, some will output zero-shaped Tensors
+and others will output Tensors of size up to max_batch_size.
+
+All Tensors in in_tensors are batched together (so, for example, labels and
+features should be batched with a single instance of this operation.
+
+Each invocation of batch emits an `id` scalar which will be used to identify
+this particular invocation when doing unbatch or its gradient.
+
+Each op which emits a non-empty batch will also emit a non-empty batch_index
+Tensor, which, is a [K, 3] matrix where each row contains the invocation's id,
+start, and length of elements of each set of Tensors present in batched_tensors.
+
+Batched tensors are concatenated along the first dimension, and all tensors in
+in_tensors must have the first dimension of the same size.
+
+in_tensors: The tensors to be batched.
+num_batch_threads: Number of scheduling threads for processing batches of work.
+ Determines the number of batches processed in parallel.
+max_batch_size: Batch sizes will never be bigger than this.
+batch_timeout_micros: Maximum number of microseconds to wait before outputting
+ an incomplete batch.
+allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does
+ nothing. Otherwise, supplies a list of batch sizes, causing the op to pad
+ batches up to one of those sizes. The entries must increase monotonically, and
+ the final entry must equal max_batch_size.
+grad_timeout_micros: The timeout to use for the gradient. See Unbatch.
+batched_tensors: Either empty tensors or a batch of concatenated Tensors.
+batch_index: If out_tensors is non-empty, has information to invert it.
+container: Controls the scope of sharing of this batch.
+id: always contains a scalar with a unique ID for this invocation of Batch.
+shared_name: Concurrently running instances of batch in the same device with the
+ same container and shared_name will batch their elements together. If left
+ empty, the op name will be used as the shared name.
+T: the types of tensors to be batched.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt
index a72f2bfe5f..118d0e2178 100644
--- a/tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt
@@ -30,9 +30,8 @@ END
attr {
name: "resize_align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1),
-which exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
attr {
diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorGetNextSync.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorGetNextSync.pbtxt
new file mode 100644
index 0000000000..641679e8ea
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IteratorGetNextSync.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "IteratorGetNextSync"
+ summary: "Gets the next output from the given iterator."
+ description: <<END
+This operation is a synchronous version IteratorGetNext. It should only be used
+in situations where the iterator does not block the calling thread, or where
+the calling thread is not a member of the thread pool used to execute parallel
+operations (e.g. in eager mode).
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt
index 6b3ba72e53..a08ed710b7 100644
--- a/tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize quantized `images` to `size` using quantized bilinear interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt
index 6dc321a544..317ad263cc 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize `images` to `size` using area interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt
index 06e645e3ee..d4f8233d25 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize `images` to `size` using bicubic interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt
index bf5201d82e..eeb0680ab8 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt
@@ -25,9 +25,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale grads by (orig_height - 1) / (height - 1), which
-exactly aligns the 4 corners of grads and original_image. If false, rescale by
-orig_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and grad tensors are
+aligned. Defaults to false.
END
}
summary: "Computes the gradient of bicubic interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt
index 0768e437fa..0673baa703 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize `images` to `size` using bilinear interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt
index fba64203c2..9a1a5fb69a 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt
@@ -25,9 +25,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale grads by (orig_height - 1) / (height - 1), which
-exactly aligns the 4 corners of grads and original_image. If false, rescale by
-orig_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and grad tensors are
+aligned. Defaults to false.
END
}
summary: "Computes the gradient of bilinear interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt
index a74db4c9dc..e6f8dc1941 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize `images` to `size` using nearest neighbor interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt
index 4ef1547eb4..8d52ca8334 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt
@@ -24,9 +24,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale grads by (orig_height - 1) / (height - 1), which
-exactly aligns the 4 corners of grads and original_image. If false, rescale by
-orig_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and grad tensors are
+aligned. Defaults to false.
END
}
summary: "Computes the gradient of nearest neighbor interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorGetItem.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorGetItem.pbtxt
new file mode 100644
index 0000000000..2869967d83
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorGetItem.pbtxt
@@ -0,0 +1,11 @@
+op {
+ graph_op_name: "TensorListGetItem"
+ summary: "Returns the item in the list with the given index."
+ description: <<END
+input_handle: the list
+index: the position in the list from which an element will be retrieved
+item: the element at that position
+
+
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListElementShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListElementShape.pbtxt
new file mode 100644
index 0000000000..ee20f4575c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListElementShape.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "TensorListElementShape"
+ summary: "The shape of the elements of the given list, as a tensor."
+ description: <<END
+ input_handle: the list
+ element_shape: the shape of elements of the list
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListReserve.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListReserve.pbtxt
new file mode 100644
index 0000000000..b5640f0ffa
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListReserve.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "TensorListReserve"
+ summary: "List of the given size with empty elements."
+ description: <<END
+element_shape: the shape of the future elements of the list
+num_elements: the number of elements to reserve
+handle: the output list
+element_dtype: the desired type of elements in the list.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorSetItem.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorSetItem.pbtxt
new file mode 100644
index 0000000000..682cf69ee2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorSetItem.pbtxt
@@ -0,0 +1,11 @@
+op {
+ graph_op_name: "TensorListSetItem"
+ summary: "Sets the index-th position of the list to contain the given tensor."
+ description: <<END
+input_handle: the list
+index: the position in the list to which the tensor will be assigned
+item: the element to be assigned to that position
+output_handle: the new list, with the element in the proper position
+
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Unbatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_Unbatch.pbtxt
new file mode 100644
index 0000000000..6d10ea606d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Unbatch.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "Unbatch"
+ summary: "Reverses the operation of Batch for a single output Tensor."
+ description: <<END
+An instance of Unbatch either receives an empty batched_tensor, in which case it
+asynchronously waits until the values become available from a concurrently
+running instance of Unbatch with the same container and shared_name, or receives
+a non-empty batched_tensor in which case it finalizes all other concurrently
+running instances and outputs its own element from the batch.
+
+batched_tensor: The possibly transformed output of Batch. The size of the first
+ dimension should remain unchanged by the transformations for the operation to
+ work.
+batch_index: The matching batch_index obtained from Batch.
+id: The id scalar emitted by Batch.
+unbatched_tensor: The Tensor corresponding to this execution.
+timeout_micros: Maximum amount of time (in microseconds) to wait to receive the
+ batched input tensor associated with a given invocation of the op.
+container: Container to control resource sharing.
+shared_name: Instances of Unbatch with the same container and shared_name are
+ assumed to possibly belong to the same batch. If left empty, the op name will
+ be used as the shared name.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnbatchGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnbatchGrad.pbtxt
new file mode 100644
index 0000000000..487b4218d5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UnbatchGrad.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "UnbatchGrad"
+ summary: "Gradient of Unbatch."
+ description: <<END
+Acts like Batch but using the given batch_index index of batching things as they
+become available. This ensures that the gradients are propagated back in the
+same session which did the forward pass.
+
+original_input: The input to the Unbatch operation this is the gradient of.
+batch_index: The batch_index given to the Unbatch operation this is the gradient
+of.
+grad: The downstream gradient.
+id: The id scalar emitted by Batch.
+batched_grad: The return value, either an empty tensor or the batched gradient.
+container: Container to control resource sharing.
+shared_name: Instances of UnbatchGrad with the same container and shared_name
+ are assumed to possibly belong to the same batch. If left empty, the op name
+ will be used as the shared name.
+END
+}
diff --git a/tensorflow/core/api_def/update_api_def.cc b/tensorflow/core/api_def/update_api_def.cc
index 1a6d15ec68..ea9a148260 100644
--- a/tensorflow/core/api_def/update_api_def.cc
+++ b/tensorflow/core/api_def/update_api_def.cc
@@ -224,14 +224,14 @@ void RemoveDocs(const std::vector<const OpDef*>& ops,
}
} // namespace
-// Returns ApiDef text representation in multi-line format
+// Returns ApiDefs text representation in multi-line format
// constructed based on the given op.
string CreateApiDef(const OpDef& op) {
- ApiDef api_def;
- FillBaseApiDef(&api_def, op);
+ ApiDefs api_defs;
+ FillBaseApiDef(api_defs.add_op(), op);
const std::vector<string> multi_line_fields = {"description"};
- string new_api_defs_str = api_def.DebugString();
+ string new_api_defs_str = api_defs.DebugString();
return PBTxtToMultiline(new_api_defs_str, multi_line_fields);
}
diff --git a/tensorflow/core/api_def/update_api_def.h b/tensorflow/core/api_def/update_api_def.h
index 5eae7e528e..1e285c0688 100644
--- a/tensorflow/core/api_def/update_api_def.h
+++ b/tensorflow/core/api_def/update_api_def.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
+#ifndef TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
+#define TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
// Functions for updating ApiDef when new ops are added.
#include "tensorflow/core/framework/op_def.pb.h"
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
-// Returns ApiDef text representation in multi-line format
+// Returns ApiDefs text representation in multi-line format
// constructed based on the given op.
string CreateApiDef(const OpDef& op);
@@ -42,4 +42,4 @@ void CreateApiDefs(const OpList& ops, const string& api_def_dir,
const string& op_file_pattern);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
+#endif // TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
diff --git a/tensorflow/core/api_def/update_api_def_test.cc b/tensorflow/core/api_def/update_api_def_test.cc
index 8948f2c1d5..4200c9da23 100644
--- a/tensorflow/core/api_def/update_api_def_test.cc
+++ b/tensorflow/core/api_def/update_api_def_test.cc
@@ -173,30 +173,32 @@ description: "Description\nfor Op1."
OpDef op;
protobuf::TextFormat::ParseFromString(op_text, &op); // NOLINT
- const string expected_api_def = R"(graph_op_name: "Op1"
-in_arg {
- name: "a"
- description: <<END
+ const string expected_api_def = R"(op {
+ graph_op_name: "Op1"
+ in_arg {
+ name: "a"
+ description: <<END
Description for a.
END
-}
-out_arg {
- name: "output"
- description: <<END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
Description for output.
END
-}
-attr {
- name: "b"
- description: <<END
+ }
+ attr {
+ name: "b"
+ description: <<END
Description for b.
END
-}
-summary: "Summary for Op1."
-description: <<END
+ }
+ summary: "Summary for Op1."
+ description: <<END
Description
for Op1.
END
+}
)";
EXPECT_EQ(expected_api_def, CreateApiDef(op));
}
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index e9bdd922ba..20c59ad42b 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1143,8 +1143,8 @@ Status DirectSession::GetOrCreateExecutors(
options.debug_options = run_state_args->debug_options;
}
- std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
+ std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
// The executor_lock_ is intentionally released while executor is
// being created.
diff --git a/tensorflow/core/common_runtime/function_testlib.h b/tensorflow/core/common_runtime/function_testlib.h
index 0bf6699f5a..3ddb26de92 100644
--- a/tensorflow/core/common_runtime/function_testlib.h
+++ b/tensorflow/core/common_runtime/function_testlib.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/function.h"
@@ -34,4 +34,4 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name,
} // namespace test
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 0e5b6b7ef8..933d700f60 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -762,7 +762,7 @@ int64 MinSystemMemory(int64 available_memory) {
// is necessary.
min_system_memory *= 2;
#endif
-#if defined(NVIDIA_TEGRA)
+#if defined(ANDROID_TEGRA)
// 1GB system mem for NVIDIA Tegra devices since they use the same mem for RAM and Video RAM
min_system_memory = 1<<30;
#endif
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h
index ff81ccd432..4e9c4abce1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_
#include "tensorflow/core/lib/gtl/int_type.h"
@@ -85,4 +85,4 @@ TF_LIB_GTL_DEFINE_INT_TYPE(CudaGpuId, int32);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
index 78e51c84c1..6d196b16ed 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
@@ -58,4 +58,4 @@ class GpuIdUtil {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_UTILS_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h
index 006b2ca448..2d49a64c0f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
#include "tensorflow/core/framework/allocator.h"
@@ -33,4 +33,4 @@ class GpuManagedAllocator : public Allocator {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 8f3a082134..8477cea126 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@@ -60,4 +60,4 @@ class GraphOptimizer {
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
diff --git a/tensorflow/core/common_runtime/memory_types.h b/tensorflow/core/common_runtime/memory_types.h
index fa0a7595f3..f854acfdc5 100644
--- a/tensorflow/core/common_runtime/memory_types.h
+++ b/tensorflow/core/common_runtime/memory_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/graph/graph.h"
@@ -45,4 +45,4 @@ Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_
diff --git a/tensorflow/core/common_runtime/pending_counts.h b/tensorflow/core/common_runtime/pending_counts.h
index 5707f52592..5e1925c401 100644
--- a/tensorflow/core/common_runtime/pending_counts.h
+++ b/tensorflow/core/common_runtime/pending_counts.h
@@ -1,5 +1,5 @@
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
@@ -328,4 +328,4 @@ inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 38003b7726..a1adc4b6b3 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
#include <unordered_map>
@@ -173,4 +173,4 @@ class ProcessFunctionLibraryRuntime {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_
diff --git a/tensorflow/core/common_runtime/profile_handler.h b/tensorflow/core/common_runtime/profile_handler.h
index 57c83c2e6f..9d31b1aecb 100644
--- a/tensorflow/core/common_runtime/profile_handler.h
+++ b/tensorflow/core/common_runtime/profile_handler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/graph/types.h"
@@ -80,4 +80,4 @@ class ProfileHandler {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index c5c204d4fa..fe4df1c106 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -134,4 +134,4 @@ class RenamedDevice : public Device {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h
index 3b6354603b..aad910f6d8 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.h
+++ b/tensorflow/core/common_runtime/rendezvous_util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
#include <map>
@@ -49,4 +49,4 @@ Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index da42c30ce9..75eb5bf0d2 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
#include <vector>
@@ -303,4 +303,4 @@ class ShapeRefiner {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_
diff --git a/tensorflow/core/common_runtime/stats_publisher_interface.h b/tensorflow/core/common_runtime/stats_publisher_interface.h
index b285420798..f063ee5297 100644
--- a/tensorflow/core/common_runtime/stats_publisher_interface.h
+++ b/tensorflow/core/common_runtime/stats_publisher_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_
#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/profile_handler.h"
@@ -61,4 +61,4 @@ std::unique_ptr<StatsPublisherInterface> CreateNoOpStatsPublisher(
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 2db7ebd795..9e152aa082 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -145,6 +145,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:worker_proto_cc",
],
)
@@ -556,3 +557,47 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:array",
],
)
+
+cc_library(
+ name = "request_id",
+ srcs = ["request_id.cc"],
+ hdrs = ["request_id.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_cc_test(
+ name = "request_id_test",
+ size = "small",
+ srcs = ["request_id_test.cc"],
+ deps = [
+ ":request_id",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "recent_request_ids",
+ srcs = ["recent_request_ids.cc"],
+ hdrs = ["recent_request_ids.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "recent_request_ids_test",
+ size = "small",
+ srcs = ["recent_request_ids_test.cc"],
+ deps = [
+ ":recent_request_ids",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
index 3deb80dff7..d3ca350e36 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
@@ -74,4 +74,4 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h
index 5fc21d3a1e..c20b40329a 100644
--- a/tensorflow/core/distributed_runtime/local_master.h
+++ b/tensorflow/core/distributed_runtime/local_master.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_
#include <memory>
@@ -98,4 +98,4 @@ class LocalMaster : public MasterInterface {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index dcc25e4426..9d4a1eb8a1 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1448,6 +1448,8 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
const auto count = run_state->count;
pss.collect_timeline =
req.options().trace_level() == RunOptions::FULL_TRACE;
+ pss.collect_rpcs =
+ req.options().trace_level() == RunOptions::FULL_TRACE;
pss.report_tensor_allocations_upon_oom =
req.options().report_tensor_allocations_upon_oom();
@@ -1610,6 +1612,8 @@ Status MasterSession::DoRunWithLocalExecution(
TRACEPRINTF("stepid %llu", step_id);
pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE;
+ pss.collect_rpcs =
+ req.options().trace_level() == RunOptions::FULL_TRACE;
pss.report_tensor_allocations_upon_oom =
req.options().report_tensor_allocations_upon_oom();
// Build the cost model every 'build_cost_model_every' steps after skipping an
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 7113d73dd7..79fa6f926e 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
@@ -702,4 +702,4 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW
+#endif // TENSORFLOW
diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr.h b/tensorflow/core/distributed_runtime/partial_run_mgr.h
index af56e723a9..e95f4da6c3 100644
--- a/tensorflow/core/distributed_runtime/partial_run_mgr.h
+++ b/tensorflow/core/distributed_runtime/partial_run_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
#include <unordered_map>
@@ -84,4 +84,4 @@ class PartialRunMgr {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.cc b/tensorflow/core/distributed_runtime/recent_request_ids.cc
new file mode 100644
index 0000000000..c30879406c
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/recent_request_ids.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+RecentRequestIds::RecentRequestIds(int num_tracked_request_ids)
+ : circular_buffer_(num_tracked_request_ids) {
+ set_.reserve(num_tracked_request_ids);
+}
+
+Status RecentRequestIds::TrackUnique(int64 request_id,
+ const string& method_name,
+ const protobuf::Message& request) {
+ mutex_lock l(mu_);
+ if (request_id == 0) {
+ // For backwards compatibility, allow all requests with request_id 0.
+ return Status::OK();
+ }
+ if (set_.count(request_id) > 0) {
+ // Note: RecentRequestIds is not strict LRU because we don't update
+ // request_id's age in the circular_buffer_ if it's tracked again. Strict
+ // LRU is not useful here because returning this error will close the
+ // current Session.
+ return errors::Aborted("The same ", method_name,
+ " request was received twice. ",
+ request.ShortDebugString());
+ }
+
+ // Remove the oldest request_id from the set_. circular_buffer_ is
+ // zero-initialized, and zero is never tracked, so it's safe to do this even
+ // when the buffer is not yet full.
+ set_.erase(circular_buffer_[next_index_]);
+ circular_buffer_[next_index_] = request_id;
+ set_.insert(request_id);
+ next_index_ = (next_index_ + 1) % circular_buffer_.size();
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.h b/tensorflow/core/distributed_runtime/recent_request_ids.h
new file mode 100644
index 0000000000..e8e45331dd
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/recent_request_ids.h
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
+
+#include <vector>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+// RecentRequestIds tracks recent 64-bit request_ids. When maximum capacity is
+// reached, the oldest request_id is evicted. Thread safe.
+//
+// Some RPCs like RecvTensor are unsafe to retry. For example, RecvTensor pairs
+// one sender and one receiver, and the receiver waits for the sender's tensor.
+// Retried RecvTensor requests are problematic, because the original RecvTensor
+// request may have consumed the sender's tensor, so a retried request might
+// block forever. RecentRequestIds identifies retried requests, so we can fail
+// them instead of blocking forever.
+//
+// Internally, recent request_ids are stored in two data structures: a set and a
+// circular buffer. The set is used for efficient lookups, and the circular
+// buffer tracks the oldest request_id. When the buffer is full, the new
+// request_id replaces the oldest request_id in the circular buffer, and the
+// oldest request_id is removed from the set.
+class RecentRequestIds {
+ public:
+ // num_tracked_request_ids should be much larger than the number of RPCs that
+ // can be received in a small time window. For example, we observed a peak RPC
+ // rate of ~700 RecvTensor RPC/s when training inception v3 on TPUs, so we
+ // currently set num_tracked_request_ids to 100,000 for RecvTensor.
+ RecentRequestIds(int num_tracked_request_ids);
+
+ // Returns OK iff request_id has not been seen in the last
+ // num_tracked_request_ids insertions. For backwards compatibility, this
+ // always returns OK for request_id 0. The method_name and the request's
+ // ShortDebugString are added to returned errors.
+ Status TrackUnique(int64 request_id, const string& method_name,
+ const protobuf::Message& request);
+
+ private:
+ mutex mu_;
+ // next_index_ indexes into circular_buffer_, and points to the next storage
+ // space to use. When the buffer is full, next_index_ points at the oldest
+ // request_id.
+ int next_index_ GUARDED_BY(mu_) = 0;
+ std::vector<int64> circular_buffer_ GUARDED_BY(mu_);
+ gtl::FlatSet<int64> set_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
diff --git a/tensorflow/core/distributed_runtime/recent_request_ids_test.cc b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc
new file mode 100644
index 0000000000..9a0facf540
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
+
+#include <algorithm>
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+Status TrackUnique(int64 request_id, RecentRequestIds* recent_request_ids) {
+ RecvTensorRequest request;
+ request.set_request_id(request_id);
+ return recent_request_ids->TrackUnique(request_id, "recent_request_ids_test",
+ request);
+}
+
+// request_id 0 is always valid.
+TEST(RecentRequestIds, Zero) {
+ RecentRequestIds recent_request_ids(1);
+ EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok());
+ EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok());
+ EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok());
+}
+
+TEST(RecentRequestIds, Unordered) {
+ // Capacity for 6 numbers.
+ RecentRequestIds recent_request_ids(6);
+
+ // Some unordered numbers to insert into request_id_set.
+ std::vector<int64> numbers = {53754, 23351, 164101, 7476,
+ 162432, 130761, 164102};
+
+ // Insert numbers[0..6) and check that all previously inserted numbers remain
+ // in the set.
+ for (int i = 0; i < 6; ++i) {
+ TF_EXPECT_OK(TrackUnique(numbers[i], &recent_request_ids));
+
+ for (int j = 0; j <= i; ++j) {
+ EXPECT_FALSE(TrackUnique(numbers[j], &recent_request_ids).ok())
+ << "i=" << i << " j=" << j;
+ }
+ }
+
+ // Insert numbers[6]. Inserting this 7th number should evict the first number
+ // from the set. The set should only contain numbers[1..7).
+ TF_EXPECT_OK(TrackUnique(numbers[6], &recent_request_ids));
+ for (int i = 1; i < 7; ++i) {
+ EXPECT_FALSE(TrackUnique(numbers[i], &recent_request_ids).ok())
+ << "i=" << i;
+ }
+
+ // Insert numbers[0] again. This should succeed because we just evicted it
+ // from the set.
+ TF_EXPECT_OK(TrackUnique(numbers[0], &recent_request_ids));
+}
+
+// Check that the oldest request_id is evicted.
+void TestOrdered(int num_request_ids) {
+ RecentRequestIds recent_request_ids(num_request_ids);
+
+ // Insert [1..101). The current number and the (num_request_ids - 1) preceding
+ // numbers should still be in the set.
+ for (int i = 1; i < 101; ++i) {
+ TF_EXPECT_OK(TrackUnique(i, &recent_request_ids));
+
+ for (int j = std::max(1, i - num_request_ids + 1); j <= i; ++j) {
+ EXPECT_FALSE(TrackUnique(j, &recent_request_ids).ok())
+ << "i=" << i << " j=" << j;
+ }
+ }
+}
+
+// Test eviction with various numbers of buckets.
+TEST(RecentRequestIds, Ordered2) { TestOrdered(2); }
+TEST(RecentRequestIds, Ordered3) { TestOrdered(3); }
+TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
+TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/request_id.cc b/tensorflow/core/distributed_runtime/request_id.cc
new file mode 100644
index 0000000000..230c6f9601
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/request_id.cc
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/request_id.h"
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+int64 GetUniqueRequestId() {
+ int64 request_id = 0;
+ while (request_id == 0) {
+ request_id = random::New64();
+ }
+ return request_id;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/request_id.h b/tensorflow/core/distributed_runtime/request_id.h
new file mode 100644
index 0000000000..a882b69ab1
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/request_id.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_
+
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Returns a request_id for use with RecentRequestIds. This number will not be
+// zero, and must be unique over RecentRequestIds' window of
+// num_tracked_request_ids. See recent_request_ids.h for more details.
+int64 GetUniqueRequestId();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_
diff --git a/tensorflow/core/distributed_runtime/request_id_test.cc b/tensorflow/core/distributed_runtime/request_id_test.cc
new file mode 100644
index 0000000000..e0dc9d9347
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/request_id_test.cc
@@ -0,0 +1,29 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/request_id.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+// Try requesting some request_ids and verify that none are zero.
+TEST(GetUniqueRequestId, Basic) {
+ for (int i = 0; i < 1000000; ++i) {
+ EXPECT_NE(GetUniqueRequestId(), 0);
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 80640c806d..dade26abc6 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -186,6 +186,7 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:graph_mgr",
+ "//tensorflow/core/distributed_runtime:recent_request_ids",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker",
"//tensorflow/core/distributed_runtime:worker_cache",
@@ -270,6 +271,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:request_id",
"//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
diff --git a/tensorflow/core/distributed_runtime/rpc/async_service_interface.h b/tensorflow/core/distributed_runtime/rpc/async_service_interface.h
index 63b0f2272d..b2730a583b 100644
--- a/tensorflow/core/distributed_runtime/rpc/async_service_interface.h
+++ b/tensorflow/core/distributed_runtime/rpc/async_service_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
namespace tensorflow {
@@ -38,4 +38,4 @@ class AsyncServiceInterface {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
index 2ab0a40f33..ecad1274cc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_call.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/macros.h"
@@ -265,4 +265,4 @@ class Call : public UntypedCall<Service> {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index c662cde9be..de9840fca8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
#include <map>
#include <memory>
@@ -93,4 +93,4 @@ Status NewHostPortGrpcChannel(const string& target,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
index 95c2c935f0..d367b83ee7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
#include "grpc++/grpc++.h"
@@ -41,4 +41,4 @@ class GrpcClientCQTag {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h
index 8770dcc3ac..473604f257 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
#include <memory>
#include "tensorflow/core/platform/types.h"
@@ -34,4 +34,4 @@ AsyncServiceInterface* NewGrpcMasterService(Master* master,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
index 412395c526..4e203e260a 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
#include "grpc++/impl/codegen/async_stream.h"
#include "grpc++/impl/codegen/async_unary_call.h"
@@ -186,4 +186,4 @@ class MasterService final {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h
index d661caaa60..c80668e899 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -24,4 +24,4 @@ namespace tensorflow {
MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
index 8ad4133540..709c3833e7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
-#define THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#ifndef TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#define TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
#include <memory>
@@ -35,4 +35,4 @@ WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#endif // TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h
index b35d4843e8..dd114d39c6 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
#include "grpc++/impl/codegen/proto_utils.h"
#include "grpc++/support/slice.h"
@@ -231,4 +231,4 @@ class UnlimitedSizeProtoSerializationTraits {
: public UnlimitedSizeProtoSerializationTraits<MessageType> {}; \
} // namespace grpc
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index c3f513d492..8b12ac1461 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
#include <memory>
@@ -141,4 +141,4 @@ class GrpcServer : public ServerInterface {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
index 300f727124..d87956a135 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
#include <memory>
#include <string>
@@ -130,4 +130,4 @@ class GrpcSession : public Session {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index 3f80bdfb70..0b6f9474dd 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
#include <utility>
@@ -96,4 +96,4 @@ class RPCState : public GrpcClientCQTag {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
index 5e81b90189..4b3a03b1d7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
#include <memory>
#include <string>
@@ -70,4 +70,4 @@ class TestCluster {
} // end namespace test
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
index bb85478347..d5e7e9f5b3 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_util.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
#include <memory>
@@ -114,4 +114,4 @@ class GrpcByteBufferSource : public ::grpc::protobuf::io::ZeroCopyInputStream {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h
index 17a307a6d9..7a35fdbca0 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
@@ -29,4 +29,4 @@ WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
const string& local_target);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index 15faf21daf..b20e744a97 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -354,7 +354,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
} // namespace
-GrpcWorker::GrpcWorker(WorkerEnv* worker_env) : Worker(worker_env) {}
+GrpcWorker::GrpcWorker(WorkerEnv* worker_env)
+ : Worker(worker_env), recv_tensor_recent_request_ids_(100000) {}
// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffers for a response object, to avoid extra protocol buffer serialization
@@ -363,11 +364,18 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
+ Status s = recv_tensor_recent_request_ids_.TrackUnique(
+ request->request_id(), "RecvTensor (GrpcWorker)", *request);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+
const int64 step_id = request->step_id();
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
- Status s = Rendezvous::ParseKey(key, &parsed);
+ s = Rendezvous::ParseKey(key, &parsed);
Device* src_dev = nullptr;
if (s.ok()) {
s = PrepareRecvTensor(parsed, &src_dev);
@@ -436,6 +444,24 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
});
}
+void GrpcWorker::LoggingAsync(const LoggingRequest* request,
+ LoggingResponse* response, StatusCallback done) {
+ auto env = this->env();
+ if (env) {
+ auto session_mgr = (SessionMgr*)env->session_mgr;
+ if (session_mgr) {
+ session_mgr->SetLogging(request->rpc_logging());
+ for (const auto& step_id : request->fetch_step_id()) {
+ session_mgr->RetrieveLogs(step_id, response);
+ }
+ if (request->clear()) {
+ session_mgr->ClearLogs();
+ }
+ }
+ }
+ done(Status::OK());
+}
+
WorkerEnv* GrpcWorker::env() { return env_; }
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
index 64d7c986da..3954af8ad8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
+#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
#include "tensorflow/core/distributed_runtime/worker.h"
namespace grpc {
@@ -39,7 +40,13 @@ class GrpcWorker : public Worker {
::grpc::ByteBuffer* response,
StatusCallback done);
+ virtual void LoggingAsync(const LoggingRequest* request,
+ LoggingResponse* response, StatusCallback done);
+
WorkerEnv* env();
+
+ private:
+ RecentRequestIds recv_tensor_recent_request_ids_;
};
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env);
@@ -50,4 +57,4 @@ std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
index fb23f8631f..1a5e2edfb2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
#include "grpc++/impl/codegen/async_stream.h"
#include "grpc++/impl/codegen/async_unary_call.h"
@@ -147,4 +147,4 @@ class WorkerService final {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
index 72dfe5c062..067dc5dff5 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
@@ -67,6 +68,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
done_ = std::move(done);
req_.set_step_id(step_id);
req_.set_rendezvous_key(key.data(), key.size());
+ req_.set_request_id(GetUniqueRequestId());
}
void Reset(WorkerCacheInterface* wc) {
diff --git a/tensorflow/core/distributed_runtime/server_lib.h b/tensorflow/core/distributed_runtime/server_lib.h
index a064d20cdb..275f526d31 100644
--- a/tensorflow/core/distributed_runtime/server_lib.h
+++ b/tensorflow/core/distributed_runtime/server_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
#include <memory>
@@ -95,4 +95,4 @@ Status NewServer(const ServerDef& server_def,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index 8db49e7f15..51b9547f53 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -43,8 +43,8 @@ SessionMgr::SessionMgr(
worker_cache_factory_(std::move(worker_cache_factory)) {}
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
- return strings::StrCat("/job:", server_def.job_name(),
- "/replica:0/task:", server_def.task_index());
+ return strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:",
+ server_def.task_index());
}
Status SessionMgr::CreateSession(const string& session,
@@ -64,8 +64,13 @@ Status SessionMgr::CreateSession(const string& session,
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
}
+ if (worker_cache != nullptr & default_worker_cache_.get() != nullptr) {
+ worker_cache->SetLogging(this->is_logging_active_);
+ }
+
CHECK(!worker_env_->local_devices.empty())
<< "The WorkerEnv must have at least one device in `local_devices`.";
+
std::vector<Device*> renamed_devices;
for (Device* d : worker_env_->local_devices) {
renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
@@ -113,4 +118,77 @@ std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
return legacy_session_;
}
+void SessionMgr::SetLogging(bool active) {
+ mutex_lock l(mu_);
+ this->is_logging_active_ = active;
+ // Legacy Session
+ if (legacy_session_) {
+ auto* worker_cache = legacy_session_->worker_cache.get();
+ if (worker_cache) {
+ worker_cache->SetLogging(active);
+ }
+ }
+
+ for (const auto& session_kv : sessions_) {
+ auto session = session_kv.second.get();
+ if (session) {
+ auto* worker_cache = session->worker_cache.get();
+ if (worker_cache) {
+ worker_cache->SetLogging(active);
+ }
+ }
+ }
+}
+
+void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
+ LoggingResponse* response) {
+ mutex_lock l(mu_);
+ // Legacy Session
+ if (legacy_session_) {
+ auto* worker_cache = legacy_session_->worker_cache.get();
+ if (worker_cache) {
+ auto step_stats = StepStats();
+ if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
+ auto* labeled_step_stats = response->add_step();
+ labeled_step_stats->set_step_id(step_id);
+ labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
+ }
+ }
+ }
+ for (const auto& session_kv : sessions_) {
+ auto session = session_kv.second.get();
+ if (session) {
+ auto* worker_cache = session->worker_cache.get();
+ if (worker_cache) {
+ auto step_stats = StepStats();
+ if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
+ auto* labeled_step_stats = response->add_step();
+ labeled_step_stats->set_step_id(step_id);
+ labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
+ }
+ }
+ }
+ }
+}
+
+void SessionMgr::ClearLogs() {
+ mutex_lock l(mu_);
+ // Legacy Session
+ if (legacy_session_) {
+ auto* worker_cache = legacy_session_->worker_cache.get();
+ if (worker_cache) {
+ worker_cache->ClearLogs();
+ }
+ }
+
+ for (const auto& session_kv : sessions_) {
+ auto session = session_kv.second.get();
+ if (session) {
+ auto* worker_cache = session->worker_cache.get();
+ if (worker_cache) {
+ worker_cache->ClearLogs();
+ }
+ }
+ }
+}
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h
index ba077c3acc..4c9702d522 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.h
+++ b/tensorflow/core/distributed_runtime/session_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
#include <functional>
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
@@ -56,6 +57,12 @@ class SessionMgr {
static string WorkerNameFromServerDef(const ServerDef& server_def);
+ void SetLogging(bool active);
+
+ void RetrieveLogs(tensorflow::int64 step_id, LoggingResponse* response);
+
+ void ClearLogs();
+
private:
const WorkerEnv* const worker_env_; // Not owned.
@@ -75,6 +82,8 @@ class SessionMgr {
std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
std::shared_ptr<WorkerSession> legacy_session_;
+ bool is_logging_active_ = false;
+
const WorkerCacheFactory worker_cache_factory_;
std::shared_ptr<WorkerSession> WorkerSessionForSessionUnlocked(
@@ -87,4 +96,4 @@ class SessionMgr {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h
index c62347926f..62fa5f3cf5 100644
--- a/tensorflow/core/distributed_runtime/worker.h
+++ b/tensorflow/core/distributed_runtime/worker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_
#include <unordered_map>
@@ -120,4 +120,4 @@ class Worker : public WorkerInterface {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_
diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h
index 9da3bb253f..0fd19ac27f 100644
--- a/tensorflow/core/distributed_runtime/worker_session.h
+++ b/tensorflow/core/distributed_runtime/worker_session.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_
#include <string>
@@ -61,4 +61,4 @@ struct WorkerSession {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_
diff --git a/tensorflow/core/example/example_parser_configuration.h b/tensorflow/core/example/example_parser_configuration.h
index 69955ec4cb..3d06bd55e2 100644
--- a/tensorflow/core/example/example_parser_configuration.h
+++ b/tensorflow/core/example/example_parser_configuration.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_
+#ifndef TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_
+#define TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_
#include <string>
#include <vector>
@@ -53,4 +53,4 @@ Status ExampleParserConfigurationProtoToFeatureVectors(
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSE_CONFIGURATION_H_
+#endif // TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSE_CONFIGURATION_H_
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index c0deb473a2..293c40e04d 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
+#ifndef TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
+#define TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
#include <array>
@@ -287,4 +287,4 @@ Status ExplicitShape(InferenceContext* c);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
+#endif // TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
new file mode 100644
index 0000000000..2c2c7e7c58
--- /dev/null
+++ b/tensorflow/core/framework/dataset.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_DATASET_H_
+#define TENSORFLOW_FRAMEWORK_DATASET_H_
+
+namespace tensorflow {
+namespace dataset {
+// Registry for stateful ops that need to be used in dataset functions.
+// See below macro for usage details.
+class WhitelistedStatefulOpRegistry {
+ public:
+ Status Add(StringPiece op_name) {
+ op_names_.insert(op_name);
+ return Status::OK();
+ }
+
+ bool Contains(StringPiece op_name) {
+ return op_names_.find(op_name) != op_names_.end();
+ }
+
+ static WhitelistedStatefulOpRegistry* Global() {
+ static WhitelistedStatefulOpRegistry* reg =
+ new WhitelistedStatefulOpRegistry;
+ return reg;
+ }
+
+ private:
+ WhitelistedStatefulOpRegistry() {}
+ WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy);
+ WhitelistedStatefulOpRegistry operator=(
+ WhitelistedStatefulOpRegistry const& copy);
+ std::set<StringPiece> op_names_;
+};
+
+} // namespace dataset
+
+// Use this macro to whitelist an op that is marked stateful but needs to be
+// used inside a map_fn in an input pipeline. This is only needed if you wish
+// to be able to checkpoint the state of the input pipeline. We currently
+// do not allow stateful ops to be defined inside of map_fns since it is not
+// possible to save their state.
+// Note that the state of the whitelisted ops inside functions will not be
+// saved during checkpointing, hence this should only be used if the op is
+// marked stateful for reasons like to avoid constant folding during graph
+// optimiztion but is not stateful.
+// If possible, try to remove the stateful flag on the op first.
+// Example usage:
+//
+// WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader");
+//
+#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS(name) \
+ WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name)
+#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
+ WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
+#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
+ static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
+ ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \
+ name)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_DATASET_H_
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
index e78b6ab5d9..870bbb141b 100644
--- a/tensorflow/core/framework/op_gen_lib.cc
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -266,35 +266,6 @@ static void StringReplace(const string& from, const string& to, string* s) {
*s = str_util::Join(split, to.c_str());
}
-static void RenameInDocs(const string& from, const string& to, OpDef* op_def) {
- const string from_quoted = strings::StrCat("`", from, "`");
- const string to_quoted = strings::StrCat("`", to, "`");
- for (int i = 0; i < op_def->input_arg_size(); ++i) {
- if (!op_def->input_arg(i).description().empty()) {
- StringReplace(from_quoted, to_quoted,
- op_def->mutable_input_arg(i)->mutable_description());
- }
- }
- for (int i = 0; i < op_def->output_arg_size(); ++i) {
- if (!op_def->output_arg(i).description().empty()) {
- StringReplace(from_quoted, to_quoted,
- op_def->mutable_output_arg(i)->mutable_description());
- }
- }
- for (int i = 0; i < op_def->attr_size(); ++i) {
- if (!op_def->attr(i).description().empty()) {
- StringReplace(from_quoted, to_quoted,
- op_def->mutable_attr(i)->mutable_description());
- }
- }
- if (!op_def->summary().empty()) {
- StringReplace(from_quoted, to_quoted, op_def->mutable_summary());
- }
- if (!op_def->description().empty()) {
- StringReplace(from_quoted, to_quoted, op_def->mutable_description());
- }
-}
-
static void RenameInDocs(const string& from, const string& to,
ApiDef* api_def) {
const string from_quoted = strings::StrCat("`", from, "`");
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c879dc6f3f..16bf5c256f 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -942,13 +943,6 @@ Status FindKernelRegistration(const DeviceType& device_type,
return Status::OK();
}
-Status FindKernelRegistration(const DeviceType& device_type, const Node& node,
- const KernelRegistration** reg,
- bool* was_attr_mismatch) {
- return FindKernelRegistration(device_type, node.def(), reg,
- was_attr_mismatch);
-}
-
} // namespace
// TODO(irving): Change const NodeDef& to const Node&
@@ -1162,24 +1156,51 @@ const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
}
#endif
-void OpKernelConstruction::CtxFailure(Status s) {
+void OpKernelConstruction::CtxFailure(const Status& s) {
VLOG(1) << s;
SetStatus(s);
}
-void OpKernelConstruction::CtxFailureWithWarning(Status s) {
+void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
LOG(WARNING) << s;
SetStatus(s);
}
-void OpKernelContext::CtxFailure(Status s) {
+void OpKernelConstruction::CtxFailure(const char* file, int line,
+ const Status& s) {
+ VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
+ << " : " << s;
+ SetStatus(s);
+}
+
+void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
+ const Status& s) {
+ LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
+ << " : " << s;
+ SetStatus(s);
+}
+
+void OpKernelContext::CtxFailure(const Status& s) {
VLOG(1) << s;
SetStatus(s);
}
-void OpKernelContext::CtxFailureWithWarning(Status s) {
+void OpKernelContext::CtxFailureWithWarning(const Status& s) {
LOG(WARNING) << s;
SetStatus(s);
}
+void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
+ VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
+ << " : " << s;
+ SetStatus(s);
+}
+
+void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
+ const Status& s) {
+ LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
+ << " : " << s;
+ SetStatus(s);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 25150499ad..b72f1405cf 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -316,8 +316,10 @@ class OpKernelConstruction {
int graph_def_version() const { return graph_def_version_; }
// Helper routines for the OP_REQUIRES macros
- void CtxFailure(Status s);
- void CtxFailureWithWarning(Status s);
+ void CtxFailure(const Status& s);
+ void CtxFailureWithWarning(const Status& s);
+ void CtxFailure(const char* file, int line, const Status& s);
+ void CtxFailureWithWarning(const char* file, int line, const Status& s);
// Unrecommended functions: these are functions that have some
// current uses but are not recommended for use, and may go away at
@@ -1014,8 +1016,10 @@ class OpKernelContext {
}
// Helper routines for the OP_REQUIRES macros
- void CtxFailure(Status s);
- void CtxFailureWithWarning(Status s);
+ void CtxFailure(const Status& s);
+ void CtxFailureWithWarning(const Status& s);
+ void CtxFailure(const char* file, int line, const Status& s);
+ void CtxFailureWithWarning(const char* file, int line, const Status& s);
// Unrecommended functions: these are functions that have some
// current uses but are not recommended for use, and may go away at
@@ -1476,40 +1480,40 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
// ...
// }
-#define OP_REQUIRES(CTX, EXP, STATUS) \
- do { \
- if (!TF_PREDICT_TRUE(EXP)) { \
- (CTX)->CtxFailure((STATUS)); \
- return; \
- } \
+#define OP_REQUIRES(CTX, EXP, STATUS) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) { \
+ (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
+ return; \
+ } \
} while (0)
-#define OP_REQUIRES_OK(CTX, ...) \
- do { \
- ::tensorflow::Status _s(__VA_ARGS__); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(_s); \
- return; \
- } \
+#define OP_REQUIRES_OK(CTX, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return; \
+ } \
} while (0)
-#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
- do { \
- if (!TF_PREDICT_TRUE(EXP)) { \
- (CTX)->CtxFailure((STATUS)); \
- (CALLBACK)(); \
- return; \
- } \
+#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) { \
+ (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
+ (CALLBACK)(); \
+ return; \
+ } \
} while (0)
-#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \
- do { \
- ::tensorflow::Status _s(STATUS); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(_s); \
- (CALLBACK)(); \
- return; \
- } \
+#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \
+ do { \
+ ::tensorflow::Status _s(STATUS); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ (CALLBACK)(); \
+ return; \
+ } \
} while (0)
} // namespace tensorflow
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index edc93aec7f..e062adffe8 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -53,7 +53,7 @@ limitations under the License.
*/
#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \
- defined(NVIDIA_TEGRA)
+ defined(ANDROID_TEGRA)
// All types are supported, so all macros are invoked.
//
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 4a4ef12635..d552ec1693 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
#include <vector>
@@ -787,4 +787,4 @@ Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
} // namespace shape_inference
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index fbfd24538b..7977841482 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
#include <vector>
#include "tensorflow/core/framework/node_def.pb.h"
@@ -98,4 +98,4 @@ class ShapeInferenceTestutil {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 4f08cdc1d7..77a3edcc10 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -615,11 +615,11 @@ void Tensor::CheckType(DataType expected_dtype) const {
void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
CHECK_EQ(dtype(), expected_dtype);
- CHECK(IsAligned());
+ CHECK(IsAligned()) << "CheckTypeAndIsAligned";
}
void Tensor::CheckIsAlignedAndSingleElement() const {
- CHECK(IsAligned());
+ CHECK(IsAligned()) << "Aligned and single element";
CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 92d10f0d8c..94c39c53a6 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -660,7 +660,8 @@ void Tensor::FillDimsAndValidateCompatibleShape(
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
gtl::ArraySlice<int64> new_sizes) {
- CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
+ CheckType(DataTypeToEnum<T>::v());
+ CHECK(IsAligned());
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
FillDimsAndValidateCompatibleShape(new_sizes, &dims);
return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
@@ -687,7 +688,8 @@ typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
gtl::ArraySlice<int64> new_sizes) const {
- CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
+ CheckType(DataTypeToEnum<T>::v());
+ CHECK(IsAligned());
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
FillDimsAndValidateCompatibleShape(new_sizes, &dims);
return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index cb8e77f1df..ded6aa0991 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -453,6 +453,13 @@ inline bool DataTypeIsInteger(DataType dt) {
return kDataTypeIsInteger.Contains(dt);
}
+// Is the dtype a signed integral type?
+constexpr DataTypeSet kDataTypeIsSigned =
+ ToSet(DT_INT8) | ToSet(DT_INT16) | ToSet(DT_INT32) | ToSet(DT_INT64);
+inline bool DataTypeIsSigned(DataType dt) {
+ return kDataTypeIsSigned.Contains(dt);
+}
+
// Is the dtype an unsigned integral type?
constexpr DataTypeSet kDataTypeIsUnsigned =
ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_UINT32) | ToSet(DT_UINT64);
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index 395329da3b..ee07db1aee 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -182,7 +182,7 @@ Status VariantDeviceCopy(
// Special casing UnaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
VariantUnaryOp op, StringPiece device, StringPiece type_name) {
- auto found = unary_op_fns.find(std::make_tuple(op, device, type_name));
+ auto found = unary_op_fns.find({op, device, type_name});
if (found == unary_op_fns.end()) return nullptr;
return &found->second;
}
@@ -195,12 +195,10 @@ void UnaryVariantOpRegistry::RegisterUnaryOpFn(
CHECK_EQ(existing, nullptr)
<< "Unary VariantUnaryOpFn for type_name: " << type_name
<< " already registered for device type: " << device;
- unary_op_fns.insert(
- std::pair<std::tuple<VariantUnaryOp, StringPiece, StringPiece>,
- VariantUnaryOpFn>(
- std::make_tuple(op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)),
- unary_op_fn));
+ unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
+ {op, GetPersistentStringPiece(device),
+ GetPersistentStringPiece(type_name)},
+ unary_op_fn));
}
namespace {
@@ -229,7 +227,7 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
UnaryVariantOpRegistry::VariantBinaryOpFn*
UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
StringPiece type_name) {
- auto found = binary_op_fns.find(std::make_tuple(op, device, type_name));
+ auto found = binary_op_fns.find({op, device, type_name});
if (found == binary_op_fns.end()) return nullptr;
return &found->second;
}
@@ -242,12 +240,10 @@ void UnaryVariantOpRegistry::RegisterBinaryOpFn(
CHECK_EQ(existing, nullptr)
<< "Unary VariantBinaryOpFn for type_name: " << type_name
<< " already registered for device type: " << device;
- binary_op_fns.insert(
- std::pair<std::tuple<VariantBinaryOp, StringPiece, StringPiece>,
- VariantBinaryOpFn>(
- std::make_tuple(op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)),
- add_fn));
+ binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
+ {op, GetPersistentStringPiece(device),
+ GetPersistentStringPiece(type_name)},
+ add_fn));
}
namespace {
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index 13f6908cae..0e2a410429 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -166,6 +166,21 @@ class UnaryVariantOpRegistry {
device_copy_fns;
// Map std::tuple<Op, device, type_name> to function.
+
+ // this breaks by falling victim to "too perfect forwarding"
+ // see https://stackoverflow.com/questions/44475317/variadic-template-issue
+ // and references therein
+ template <typename Op>
+ struct FuncTuple {
+ FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
+ : op_type_(op), device_(dev), typename_(tname){};
+ Op op_type_;
+ StringPiece device_, typename_;
+ };
+ //friend declaration for operator==
+ // needed for clang
+ template <typename Op>
+ friend bool operator==(const FuncTuple<Op> &l, const FuncTuple<Op> &r);
struct TupleHash {
template <typename Op>
std::size_t operator()(
@@ -176,18 +191,24 @@ class UnaryVariantOpRegistry {
ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
return ret;
}
+
+ template <typename Op>
+ std::size_t operator()(const FuncTuple<Op>& x) const {
+ // The hash of an enum is just its value as a std::size_t.
+ std::size_t ret = static_cast<std::size_t>(x.op_type_);
+ ret = Hash64Combine(ret, sp_hasher_(x.device_));
+ ret = Hash64Combine(ret, sp_hasher_(x.typename_));
+ return ret;
+ }
StringPieceHasher sp_hasher_;
};
- std::unordered_map<std::tuple<VariantUnaryOp, StringPiece, StringPiece>,
- VariantUnaryOpFn, TupleHash>
+ std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
unary_op_fns;
- std::unordered_map<std::tuple<VariantBinaryOp, StringPiece, StringPiece>,
- VariantBinaryOpFn, TupleHash>
+ std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
binary_op_fns;
// Find or insert a string into a persistent string storage
- // container; return the StringPiece pointing to the permanent
- // string location.
+ // container; return the StringPiece pointing to the permanent string location.
static StringPiece GetPersistentStringPiece(const string& str) {
const auto string_storage = PersistentStringStorage();
auto found = string_storage->find(str);
@@ -199,7 +220,12 @@ class UnaryVariantOpRegistry {
}
}
};
-
+template <typename Op>
+inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
+ const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
+ return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
+ (lhs.typename_ == rhs.typename_);
+}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
// function, or if it's a serialized Variant that cannot be decoded.
@@ -283,8 +309,8 @@ Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
return errors::Internal(
"No unary variant binary_op function found for binary variant op "
"enum: ",
- op, " Variant type_name: '", a.TypeName(),
- "' for device type: ", device);
+ op, " Variant type_name: '", a.TypeName(), "' for device type: ",
+ device);
}
return (*binary_op_fn)(ctx, a, b, out);
}
diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc
index b1e6cf64e8..4118f14f8b 100644
--- a/tensorflow/core/graph/costmodel.cc
+++ b/tensorflow/core/graph/costmodel.cc
@@ -57,10 +57,10 @@ void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) {
const int local_id = cm.Id(n);
const int global_id = Id(n);
if (local_id < 0 || global_id < 0) continue;
- Ensure(global_id);
+ int num_slots = cm.slot_bytes_[local_id].size();
+ Ensure(global_id, num_slots);
count_[global_id] += cm.count_[local_id];
time_[global_id] += cm.time_[local_id];
- int num_slots = cm.slot_bytes_[local_id].size();
if (num_slots > 0) {
if (slot_bytes_[global_id].empty()) {
slot_bytes_[global_id].resize(num_slots);
@@ -78,11 +78,11 @@ void CostModel::MergeFromGlobal(const CostModel& cm) {
CHECK(is_global_);
CHECK_EQ(true, cm.is_global());
const int num_nodes = cm.count_.size();
- Ensure(num_nodes);
- for (int i = 0; i < num_nodes; ++i) {
+ for (int i = num_nodes - 1; i >= 0; --i) {
count_[i] += cm.count_[i];
time_[i] += cm.time_[i];
int num_slots = cm.slot_bytes_[i].size();
+ Ensure(i, num_slots);
if (num_slots > 0) {
if (slot_bytes_[i].empty()) {
slot_bytes_[i].resize(num_slots);
@@ -106,7 +106,7 @@ void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
// copy/send/recv nodes, feed/fetch, etc.
if (iter == map.end()) continue;
int32 global_id = iter->second;
- Ensure(global_id);
+ Ensure(global_id, ns.output_size());
int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros();
count_[global_id]++;
time_[global_id] += elapsed_micros;
@@ -122,7 +122,7 @@ void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
}
}
-void CostModel::Ensure(int id) {
+void CostModel::Ensure(int id, int num_outputs) {
if (slot_bytes_.size() <= static_cast<size_t>(id)) {
slot_bytes_.resize(id + 1);
count_.resize(id + 1);
@@ -131,25 +131,37 @@ void CostModel::Ensure(int id) {
max_exec_time_.resize(id + 1);
output_port_alloc_ids_.resize(id + 1);
}
+ if (num_outputs > 0) {
+ auto perslot = &slot_bytes_[id];
+ auto output_port_alloc_ids = &output_port_alloc_ids_[id];
+ auto max_mem_usage = &max_mem_usage_[id];
+
+ CHECK_LE(perslot->size(), num_outputs);
+ DCHECK_EQ(output_port_alloc_ids->size(), perslot->size());
+ DCHECK_EQ(max_mem_usage->output_port_mem.size(), perslot->size());
+ DCHECK_EQ(max_mem_usage->output_port_shape.size(), perslot->size());
+ DCHECK_EQ(max_mem_usage->output_port_type.size(), perslot->size());
+
+ perslot->resize(num_outputs, Bytes(-1));
+ output_port_alloc_ids->resize(num_outputs, -1);
+ max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1));
+ max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_);
+ max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID);
+ }
}
void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
const int id = Id(node);
if (id < 0) return;
- Ensure(id);
+ // Do not resize the number of slots before checking its existing number of
+ // slots.
+ Ensure(id, 0);
auto perslot = &slot_bytes_[id];
- auto max_mem_usage = &max_mem_usage_[id];
- auto output_port_alloc_ids = &output_port_alloc_ids_[id];
if (!perslot->empty()) {
CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node="
<< node->name();
- } else {
- perslot->resize(num_outputs, Bytes(-1));
- output_port_alloc_ids->resize(num_outputs, -1);
- max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1));
- max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_);
- max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID);
}
+ Ensure(id, num_outputs);
}
void CostModel::RecordCount(const Node* node, int count) {
@@ -198,7 +210,7 @@ void CostModel::RecordTime(const Node* node, Microseconds time) {
const int id = Id(node);
if (id < 0) return;
DCHECK(node->IsOp()) << node->DebugString();
- Ensure(id);
+ Ensure(id, node->num_outputs());
time_[id] += time;
}
@@ -240,7 +252,10 @@ void CostModel::RecordMaxMemorySize(const Node* node, int output_slot,
const DataType& dtype) {
const int id = Id(node);
if (id < 0) return;
- Ensure(id);
+ CHECK_LT(output_slot, node->num_outputs())
+ << "Unexpected output slot for node " << node->DebugString() << ". Got "
+ << output_slot << " but its num_outputs is " << node->num_outputs();
+ Ensure(id, node->num_outputs());
auto& current_max = max_mem_usage_[id].output_port_mem[output_slot];
// If the memory allocator doesn't track memory usage, let's infer a lower
// bound from the tensor shape and its data type.
@@ -316,7 +331,7 @@ void CostModel::RecordMemoryStats(const Node* node,
void CostModel::RecordMaxExecutionTime(const Node* node, Microseconds time) {
const int id = Id(node);
if (id < 0) return;
- Ensure(id);
+ Ensure(id, node->num_outputs());
max_exec_time_[id] = std::max(max_exec_time_[id], time);
}
@@ -332,7 +347,7 @@ void CostModel::RecordAllocationId(const Node* node, int output_slot,
int64 alloc_id) {
const int id = Id(node);
if (id < 0) return;
- Ensure(id);
+ Ensure(id, node->num_outputs());
output_port_alloc_ids_[id][output_slot] = alloc_id;
}
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
index 081eb2ff4c..c60a946c2c 100644
--- a/tensorflow/core/graph/costmodel.h
+++ b/tensorflow/core/graph/costmodel.h
@@ -183,8 +183,8 @@ class CostModel {
const bool is_global_;
- // Resizes vectors so that they are large enough for "id".
- void Ensure(int id);
+ // Resizes vectors so that they are large enough for "id" and id's outputs.
+ void Ensure(int id, int num_outputs);
// Nodes and Edges whose count is < this value
// get type/byte estimates of 0.
diff --git a/tensorflow/core/graph/gradients.h b/tensorflow/core/graph/gradients.h
index 75906e6ce9..ddfed084b0 100644
--- a/tensorflow/core/graph/gradients.h
+++ b/tensorflow/core/graph/gradients.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
+#define TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
@@ -55,4 +55,4 @@ Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index 852e69737b..b7eaf8dc63 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -85,10 +85,7 @@ struct Costs {
typedef NanoSeconds Duration;
// Overall cost of running the graph; latency.
- // Mean
Duration execution_time;
- Duration min_execution_time;
- Duration max_execution_time;
// Computation cost of running the graph.
Duration compute_time;
diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
index 8fd1801863..ea4320687a 100644
--- a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
@@ -117,8 +117,6 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
LOG(ERROR) << "Failed to measure graph performance: "
<< status.error_message();
costs->execution_time = Costs::Duration::max();
- costs->max_execution_time = Costs::Duration::max();
- costs->min_execution_time = 0;
return status;
}
@@ -126,8 +124,6 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
// to filter out outliers.
RobustStats stats(times);
costs->execution_time = Costs::Duration(stats.mean());
- costs->max_execution_time = Costs::Duration(stats.hi());
- costs->min_execution_time = Costs::Duration(stats.lo());
return Status::OK();
}
diff --git a/tensorflow/core/grappler/costs/op_context.h b/tensorflow/core/grappler/costs/op_context.h
index 735a1e68ea..6391de4a91 100644
--- a/tensorflow/core/grappler/costs/op_context.h
+++ b/tensorflow/core/grappler/costs/op_context.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
+#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -36,4 +36,4 @@ struct OpContext {
} // end namespace grappler
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_
diff --git a/tensorflow/core/grappler/costs/op_performance_data.proto b/tensorflow/core/grappler/costs/op_performance_data.proto
index 1d623b8db8..37f9ebd6a1 100644
--- a/tensorflow/core/grappler/costs/op_performance_data.proto
+++ b/tensorflow/core/grappler/costs/op_performance_data.proto
@@ -58,11 +58,18 @@ message LogNormalDistribution {
double sigma = 2;
}
+message SessionInfo {
+ int64 intra_op_parallelism = 1;
+}
+
// Performance data for tensorflow operations
message OpPerformance {
// The op
OpInfo op = 1;
+ // Information about the session configs.
+ SessionInfo session_info = 12;
+
// The node name (optional). Makes it easier to associate the performance data
// with a specific graph node.
string node = 5;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index c180250908..9db6d46266 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
+#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
#include <list>
#include <memory>
@@ -139,8 +139,8 @@ class FIFOManager : public ReadyNodeManager {
public:
FIFOManager() : ReadyNodeManager() {}
~FIFOManager() override {}
- virtual void Init(
- const std::unordered_map<const NodeDef*, NodeState>* node_state) {}
+ void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
+ override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() override {
CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
@@ -342,4 +342,4 @@ class VirtualScheduler {
} // namespace grappler
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 149f6fc735..2f8549cf39 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -134,6 +134,7 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
const NodeDef* node = name_to_node[NodeName(root)];
if (!node) {
*ill_formed = true;
+ VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
return {};
}
queue.push_back(node);
@@ -153,6 +154,7 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
for (const string& input : node->input()) {
const NodeDef* in = name_to_node[NodeName(input)];
if (!in) {
+ VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
*ill_formed = true;
return {};
}
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 990a07c86c..9c544c82bf 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -436,23 +436,31 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
return true;
}
+NodeDef* ArithmeticOptimizer::AddNode(const NodeDef& node, StringPiece suffix,
+ bool copy_node) {
+ return AddNode(OptimizedNodeName(node, suffix), copy_node ? &node : nullptr);
+}
+
NodeDef* ArithmeticOptimizer::AddNode(const string& name,
const NodeDef* node_to_copy) {
NodeDef* new_node = optimized_graph_->add_node();
- const string name_with_prefix =
- AddPrefixToNodeName(name, kArithmeticOptimizer);
- node_map_->AddNode(NodeName(name_with_prefix), new_node);
+ node_map_->AddNode(NodeName(name), new_node);
if (node_to_copy != nullptr) {
*new_node = *node_to_copy;
}
- new_node->set_name(name_with_prefix);
+ new_node->set_name(name);
return new_node;
}
-bool ArithmeticOptimizer::OptimizedNodeExists(const string& name) {
- const string name_with_prefix =
- AddPrefixToNodeName(name, kArithmeticOptimizer);
- return node_map_->NodeExists(name_with_prefix);
+string ArithmeticOptimizer::OptimizedNodeName(const NodeDef& node,
+ StringPiece suffix) const {
+ return AddPrefixToNodeName(strings::StrCat(node.name(), "_", suffix),
+ kArithmeticOptimizer);
+}
+
+bool ArithmeticOptimizer::OptimizedNodeExists(const NodeDef& node,
+ StringPiece suffix) const {
+ return node_map_->NodeExists(OptimizedNodeName(node, suffix));
}
bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
@@ -668,17 +676,19 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const DataType src_type = GetSourceDataType(*cast);
const DataType dst_type = GetDestinationDataType(*cast);
if (IsNumberType(src_type) && IsNumberType(dst_type) &&
- DataTypeSize(src_type) < DataTypeSize(dst_type)) {
- NodeDef* new_transpose =
- AddNode(StrCat(transpose->name(), "_", DataTypeString(src_type)),
- transpose);
+ DataTypeSize(src_type) < DataTypeSize(dst_type) &&
+ !OptimizedNodeExists(*cast, DataTypeString(dst_type)) &&
+ !OptimizedNodeExists(*transpose, DataTypeString(src_type))) {
+ NodeDef* new_transpose = AddNode(*transpose, DataTypeString(src_type),
+ /*copy_node=*/true);
(*new_transpose->mutable_attr())["T"].set_type(src_type);
new_transpose->set_input(0, cast->input(0));
node_map_->AddOutput(input->name(), new_transpose->name());
node_map_->AddOutput(NodeName(new_transpose->input(1)),
new_transpose->name());
- NodeDef* new_cast = AddNode(StrCat(cast->name(), "_new"), cast);
+ NodeDef* new_cast =
+ AddNode(*cast, DataTypeString(dst_type), /*copy_node=*/true);
new_cast->set_input(0, new_transpose->name());
node_map_->AddOutput(new_transpose->name(), new_cast->name());
@@ -754,7 +764,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// multiply can be constant-folded. TODO(jingyue): When the weights aren't
// constant, this should also help performance a bit and memory usage a lot,
// since the weights tend to be smaller than the activations.
- if (weights->op() == "Const") {
+ if (weights->op() == "Const" &&
+ !OptimizedNodeExists(*weights, StrCat("scaled_", conv->name()))) {
const NodeDef* source = node_map_->GetNode(
GetTailOfValuePreservingChain(*node, *node_map_, nodes_to_preserve_)
->input(0));
@@ -773,7 +784,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
scale_tensor.tensor_shape().dim_size() == 0) {
// Create new node `scaled_weights`.
NodeDef* scaled_weights = AddNode(
- StrCat(weights->name(), "_scaled_", conv->name()), nullptr);
+ *weights, StrCat("scaled_", conv->name()), /*copy_node=*/false);
scaled_weights->set_op("Mul");
scaled_weights->set_device(weights->device());
(*scaled_weights->mutable_attr())["T"] =
@@ -810,9 +821,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
if (node->op() == "Mul" && node->input(0) == node->input(1) &&
- !OptimizedNodeExists(StrCat(node->name(), "_square"))) {
- NodeDef* new_square_node =
- AddNode(strings::StrCat(node->name(), "_square"), node);
+ !OptimizedNodeExists(*node, "square")) {
+ NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true);
new_square_node->set_op("Square");
for (int i = 1; i < new_square_node->input_size(); ++i) {
new_square_node->set_input(i - 1, new_square_node->input(i));
@@ -847,8 +857,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
break;
}
}
- const string mul_node_name = StrCat(node->name(), "_mul");
- if (all_equal && !OptimizedNodeExists(mul_node_name)) {
+ if (all_equal && !OptimizedNodeExists(*node, "const") &&
+ !OptimizedNodeExists(*node, "mul")) {
// 1. Create constant node with value N.
const auto type = GetDataTypeFromAttr(*node, "T");
Tensor t(type, TensorShape({}));
@@ -859,15 +869,14 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
return "";
}
TensorValue value(&t);
- NodeDef* new_const_node =
- AddNode(StrCat(node->name(), "_const"), nullptr);
+ NodeDef* new_const_node = AddNode(*node, "const", /*copy_node=*/false);
*new_const_node =
ConstantFolding::CreateNodeDef(new_const_node->name(), value);
new_const_node->set_device(node->device());
nodes_to_simplify->PushBack(new_const_node);
// 2. Replace the aggregate node with Mul(Const(N), x).
- NodeDef* new_mul_node = AddNode(mul_node_name, nullptr);
+ NodeDef* new_mul_node = AddNode(*node, "mul", /*copy_node=*/false);
new_mul_node->set_op("Mul");
new_mul_node->set_device(node->device());
SetDataTypeToAttr(type, "T", new_mul_node);
@@ -892,7 +901,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// to the following:
// Mul(x, AddN(y1, y2, y3, ... yn))
if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
- !OptimizedNodeExists(StrCat(node->name(), "_hoist_add"))) {
+ !OptimizedNodeExists(*node, "hoist_add") &&
+ !OptimizedNodeExists(*node, "hoist_mul")) {
// Determine the set of common factors if the input nodes are all Mul nodes.
std::set<string> common_factors;
for (int i = 0; i < node->input_size(); ++i) {
@@ -946,10 +956,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
if (shapes_match) {
// 1. Use a copy of the first Mul node for the outer multiplication.
- NodeDef* new_mul_node = AddNode(StrCat(node->name(), "_hoist_mul"),
+ NodeDef* new_mul_node = AddNode(OptimizedNodeName(*node, "hoist_mul"),
node_map_->GetNode(node->input(0)));
- NodeDef* new_add_node =
- AddNode(StrCat(node->name(), "_hoist_add"), node);
+ NodeDef* new_add_node = AddNode(*node, "hoist_add", /*copy_node=*/true);
new_mul_node->set_device(node->device());
new_mul_node->set_input(0, common_factor);
node_map_->AddOutput(common_factor, new_mul_node->name());
@@ -978,7 +987,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// Fold Transpose into matrix multiplication.
if ((node->op() == "MatMul" || node->op() == "SparseMatMul" ||
node->op() == "BatchMatMul") &&
- !OptimizedNodeExists(StrCat(node->name(), "_fused"))) {
+ !OptimizedNodeExists(*node, "fused")) {
const NodeDef* a = node_map_->GetNode(node->input(0));
const NodeDef* b = node_map_->GetNode(node->input(1));
bool is_complex = false;
@@ -996,7 +1005,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
IsInnerMatrixTransposeNode(*b, node_map_.get());
if (a_is_foldable || b_is_foldable) {
- NodeDef* new_op = AddNode(StrCat(node->name(), "_fused"), node);
+ NodeDef* new_op = AddNode(*node, "fused", /*copy_node=*/true);
if (a_is_foldable) {
const string attr_a =
node->op() == "BatchMatMul" ? "adj_x" : "transpose_a";
@@ -1021,7 +1030,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// Fold Conj into Transpose or ConjugateTranspose.
if ((node->op() == "Conj" || node->op() == "Transpose" ||
node->op() == "ConjugateTranspose") &&
- !OptimizedNodeExists(StrCat(node->name(), "_fused"))) {
+ !OptimizedNodeExists(*node, "fused")) {
const NodeDef* input = node_map_->GetNode(node->input(0));
const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
const NodeDef* conj_op = node->op() == "Conj" ? node : input;
@@ -1029,7 +1038,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
if ((transpose_op->op() == "Transpose" ||
transpose_op->op() == "ConjugateTranspose") &&
conj_op->op() == "Conj") {
- NodeDef* new_op = AddNode(StrCat(node->name(), "_fused"), transpose_op);
+ NodeDef* new_op =
+ AddNode(OptimizedNodeName(*node, "fused"), transpose_op);
// Flip the type of transpose op to absorb the conjugation.
new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
: "Transpose");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index ec26979238..afd538db40 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -48,7 +48,13 @@ class ArithmeticOptimizer : public GraphOptimizer {
private:
// Returns true is a node with given name and the optimizer prefix already
// exists.
- bool OptimizedNodeExists(const string& name);
+ string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
+ bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const;
+
+ // Creates a new node in the graph, with name equal to that of node, prefixed
+ // with "ArithmeticOptimizer/" and the given suffix. Also updates node_map_,
+ // and optionally copies node into the new node if copy_node is true.
+ NodeDef* AddNode(const NodeDef& node, StringPiece suffix, bool copy_node);
// Creates a new node in the graph, prefixed with "ArithmeticOptimizer/",
// updates node_map_, and optionally copies *node_to_copy into the new
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index b5b1ec7021..2a82b25058 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -627,7 +627,7 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(0, std::count_if(
@@ -651,7 +651,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, std::count_if(
@@ -673,7 +673,7 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, std::count_if(
@@ -706,7 +706,7 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, std::count_if(
@@ -730,7 +730,7 @@ TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
const NodeDef* transpose_node = nullptr;
@@ -766,7 +766,7 @@ TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
int num_transposes = 0;
@@ -800,7 +800,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
std::set<string> nodes_after_optimization;
@@ -833,7 +833,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
for (const NodeDef& node : output.node()) {
@@ -860,7 +860,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
NodeMap node_map(&output);
@@ -889,7 +889,7 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(6, output.node_size());
@@ -920,7 +920,7 @@ TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
NodeMap node_map(&output);
@@ -962,7 +962,7 @@ TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
NodeMap node_map(&output);
@@ -992,7 +992,7 @@ TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
NodeMap node_map(&output);
@@ -1031,11 +1031,15 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+
+ item.graph.Swap(&output);
TF_EXPECT_OK(
ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
NodeMap node_map(&output);
@@ -1043,7 +1047,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
const NodeDef* transpose_node =
CHECK_NOTNULL(node_map.GetNode(OptimizedName("Transpose_uint8")));
const NodeDef* cast_node =
- CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_new")));
+ CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_float")));
const NodeDef* weights_node =
CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D")));
const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
@@ -1080,11 +1084,11 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(
ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
NodeMap node_map(&output);
@@ -1113,7 +1117,7 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(1, std::count_if(
@@ -1133,7 +1137,7 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(0, std::count_if(
@@ -1152,7 +1156,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph = output;
+ item.graph.Swap(&output);
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
EXPECT_EQ(0, std::count_if(
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 6860447fb8..0aeff6222c 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -128,6 +128,42 @@ bool AllValuesAre(const TensorProto& tensor, const T& value) {
return false;
}
+// Add new_input as a control input to node if it does not already depend on it.
+// TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
+// clean up code that should be using them.
+bool MaybeAddControlInput(const string& new_input, NodeDef* node,
+ GraphDef* graph, NodeMap* node_map) {
+ bool already_exists = false;
+ for (const string& input : node->input()) {
+ if (input == new_input || AsControlDependency(input) == new_input) {
+ already_exists = true;
+ break;
+ }
+ }
+ if (!already_exists) {
+ const string ctrl_dep =
+ ConstantFolding::AddControlDependency(new_input, graph, node_map);
+ node->add_input(ctrl_dep);
+ node_map->AddOutput(NodeName(new_input), node->name());
+ }
+ return !already_exists;
+}
+
+// Remove old_input as a control input to node.
+bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
+ GraphDef* graph, NodeMap* node_map) {
+ for (int i = 0; i < node->input_size(); ++i) {
+ const string& input = node->input(i);
+ if (IsControlInput(input) && AsControlDependency(old_input) == input) {
+ node->mutable_input()->SwapElements(i, node->input_size() - 1);
+ node->mutable_input()->RemoveLast();
+ node_map->RemoveOutput(NodeName(old_input), node->name());
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
@@ -1524,14 +1560,15 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
//
// + + = parent
// / \ / \
- // Const + -- > X + = children
+ // C + -- > X + = children
// / \ / \
- // X Y Const Y = leaves
+ // X Y C Y = leaves
//
- // where '+' denotes an associative and commutative operator like addition
- // or multiplication. This optimization pushes constants down in the tree
- // to canonicalize it. Moreoever, in cases where the child node has a
- // constant input we will create a node that can be folded, e.g.
+ // where C is constant and X is non-constant, and '+' denotes an
+ // associative and commutative operator like addition or multiplication.
+ // This optimization pushes constants down in the tree to canonicalize it.
+ // Moreoever, in cases where the child node has a second constant input Y
+ // we will create a leaf node that can be folded, e.g.
//
// Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
//
@@ -1540,7 +1577,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
// division/multiplication.
// Don't touch BiasAdd since they can't handle vectors as their first
// inputs.
- if ((IsAdd(*node) || is_mul) && NumNonControlInputs(*node) == 2) {
+ if (has_fetch_ && (IsAdd(*node) || is_mul) &&
+ NumNonControlInputs(*node) == 2) {
NodeDef* left_child = node_map_->GetNode(node->input(0));
NodeDef* right_child = node_map_->GetNode(node->input(1));
// One child must be constant, and the other the same op as the parent.
@@ -1556,18 +1594,21 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
node->device() != right_child->device()) {
continue;
}
- NodeDef* child_node = left_child_is_constant ? right_child : left_child;
+ NodeDef* op_child_node =
+ left_child_is_constant ? right_child : left_child;
+ NodeDef* const_child_node =
+ left_child_is_constant ? left_child : right_child;
// Make sure that it is safe to change the value of the child node->
- if (child_node->input_size() < 2 ||
- NumNonControlOutputs(*child_node, *node_map_) > 1 || !has_fetch_ ||
- nodes_to_preserve_.find(child_node->name()) !=
+ if (op_child_node->input_size() < 2 ||
+ NumNonControlOutputs(*op_child_node, *node_map_) > 1 ||
+ nodes_to_preserve_.find(op_child_node->name()) !=
nodes_to_preserve_.end()) {
continue;
}
// Identify the nodes to swap.
- const NodeDef* left_leaf = node_map_->GetNode(child_node->input(0));
- const NodeDef* right_leaf = node_map_->GetNode(child_node->input(1));
+ NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0));
+ NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1));
const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
if (left_leaf_is_constant && right_leaf_is_constant) {
@@ -1576,15 +1617,27 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
}
const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
const int parent_const_input = left_child_is_constant ? 0 : 1;
+ const auto& child_output = node_map_->GetOutputs(op_child_node->name());
+ if (child_output.find(const_child_node) != child_output.end()) {
+ // If there is a control edge from the child op to C, the transformation
+ // would create a cycle in the graph. We know that it must be a control
+ // edge. We can replace such a control edge with a control edge from A
+ // to C.
+ CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
+ graph_, node_map_.get()));
+ NodeDef* other_leaf = left_leaf_is_constant ? left_leaf : right_leaf;
+ MaybeAddControlInput(other_leaf->name(), const_child_node, graph_,
+ node_map_.get());
+ }
// Swap the constant child with a non-constant leaf node.
node_map_->UpdateInput(node->name(), node->input(parent_const_input),
- child_node->input(non_const_leaf_input));
- node_map_->UpdateInput(child_node->name(),
- child_node->input(non_const_leaf_input),
+ op_child_node->input(non_const_leaf_input));
+ node_map_->UpdateInput(op_child_node->name(),
+ op_child_node->input(non_const_leaf_input),
node->input(parent_const_input));
std::swap(*node->mutable_input(parent_const_input),
- *child_node->mutable_input(non_const_leaf_input));
+ *op_child_node->mutable_input(non_const_leaf_input));
graph_modified_ = true;
}
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 6aadd97508..18acc91e8a 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -52,7 +52,6 @@ class ConstantFolding : public GraphOptimizer {
private:
string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
- string OptimizedNodeName(const NodeDef& node) const;
bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const;
bool IsReallyConstant(const NodeDef& node) const;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 2db3dc6993..849a88770a 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -80,18 +80,25 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
TEST_F(ConstantFoldingTest, AddTree) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output c1 = ops::Const(s.WithOpName("c1"), 2.0f, {1});
Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
- Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
+ Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
+ Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child),
+ 1.0f, {1});
Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);
- Output mul_child = ops::Mul(s.WithOpName("mul_child"), c2, x);
- Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c1, mul_child);
- Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c2, x);
+
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
+ Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
+ Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
+ Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
+ Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
+ Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
Output addmul_parent =
- ops::Mul(s.WithOpName("addmul_parent"), c1, addmul_child);
+ ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child);
GrapplerItem item;
item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
@@ -102,15 +109,21 @@ TEST_F(ConstantFoldingTest, AddTree) {
Status status = fold.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(9, output.node_size());
-
- // We expect the following rewrite(s) to occur (for both Add and Mul):
+ // We expect the following rewrite(s) to occur:
+ //
// + + +
// / \ / \ / \
- // 2.0 + --> x + --> x 4.0
- // / \ / \
- // 2.0 x 2.0 2.0
+ // 1.0 + --> x + --> x 3.0
+ // / \ / \
+ // 2.0 x 1.0 2.0
+ //
+ // * * *
+ // / \ / \ / \
+ // 4.0 * --> y * --> y 20.0
+ // / \ / \
+ // 5.0 y 4.0 5.0
+ EXPECT_EQ(11, output.node_size());
for (const auto& node : output.node()) {
if (node.name() == "add_child") {
EXPECT_EQ("Const", node.op());
@@ -130,26 +143,26 @@ TEST_F(ConstantFoldingTest, AddTree) {
} else if (node.name() == "mul_parent") {
EXPECT_EQ("Mul", node.op());
EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("y", node.input(0));
EXPECT_EQ("mul_child", node.input(1));
} else if (node.name() == "addmul_child") {
// Unchanged.
EXPECT_EQ("Add", node.op());
EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("c2", node.input(0));
+ EXPECT_EQ("c4", node.input(0));
EXPECT_EQ("x", node.input(1));
}
}
- // Check that the reciprocals have the expected value.
- std::vector<string> fetch = {"c4"};
+ // Check that the result nodes have the expected value.
+ std::vector<string> fetch = {"c3", "c20"};
auto tensor_expected = EvaluateNodes(item.graph, fetch);
EXPECT_EQ(fetch.size(), tensor_expected.size());
fetch = {"add_child", "mul_child"};
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
- test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
+ test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
}
}
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 1f68ecbade..d2da125236 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -58,11 +58,7 @@ void PruneControlInputs(NodeDef* node) {
int pos = 0;
while (pos < node->input_size()) {
const string& input = node->input(pos);
- // TODO(rmlarsen): Remove control inputs that also appears as a regular
- // inputs. Currently, doing so breaks testControlFlowStrictness in
- // python/framework/function_test.
- // if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
- if (IsControlInput(input) && !inputs.insert(input).second) {
+ if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
VLOG(1) << "**** Removing duplicate control input: " << input
<< " from node " << node->DebugString();
node->mutable_input()->SwapElements(pos, node->input_size() - 1);
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
index 3f6f418bee..02d8a0f32a 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_
#include <unordered_set>
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
@@ -73,4 +73,4 @@ class DependencyOptimizer : public GraphOptimizer {
} // end namespace grappler
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index ea7b05d381..735d78e7ee 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -590,7 +590,7 @@ class NodeProcessor : public GraphProcessor {
// to ensure added_node is in the same frame with node_.
NodeDef* added_node = graph_->add_node();
*added_node = *input_node;
- string base_name = strings::StrCat(node_->name(), "-", input_node->name());
+ string base_name = strings::StrCat(node_->name(), "-", input_index);
string node_name = LayoutOptimizerNode(base_name);
added_node->set_name(node_name);
*node_->mutable_input(input_index) = node_name;
@@ -1647,12 +1647,32 @@ class StridedSliceProcessor : public SliceProcessor {
return errors::InvalidArgument("invalid mask value: ", i);
}
if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK();
- if (i == 2 || i == 3) i += 2;
- if (i == 4 || i == 5) i += 4;
- if (i == 6 || i == 7) i += 6;
- if (i == 8 || i == 9) i -= 6;
- if (i == 10 || i == 11) i -= 4;
- if (i == 12 || i == 13) i -= 2;
+ switch (i) {
+ case 2:
+ case 3:
+ i += 2;
+ break;
+ case 4:
+ case 5:
+ i += 4;
+ break;
+ case 6:
+ case 7:
+ i += 6;
+ break;
+ case 8:
+ case 9:
+ i -= 6;
+ break;
+ case 10:
+ case 11:
+ i -= 4;
+ break;
+ case 12:
+ case 13:
+ i -= 2;
+ break;
+ }
node_->mutable_attr()->at(mask).set_i(i);
return Status::OK();
}
@@ -2056,6 +2076,7 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item,
const TuningConfig& config, GraphDef* output) {
auto status = graph_properties.AnnotateOutputShapes(output);
if (!status.ok()) {
+ VLOG(1) << "Annotate shape return status: " << status.ToString();
*output = item.graph;
return status;
}
@@ -2080,6 +2101,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphProperties graph_properties(item);
auto status = graph_properties.InferStatically(false);
if (!status.ok()) {
+ VLOG(1) << "Infer shape return status: " << status.ToString();
*output = item.graph;
return status;
}
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 587642c96e..5cb366df2d 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -172,8 +172,7 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
- string input_name =
- strings::StrCat("Conv2DBackpropInput-InputSizes", "-", "LayoutOptimizer");
+ string input_name = "Conv2DBackpropInput-0-LayoutOptimizer";
auto input_sizes_node = node_map.GetNode(input_name);
CHECK(input_sizes_node);
auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
@@ -288,7 +287,7 @@ TEST_F(LayoutOptimizerTest, Pad) {
auto pad = node_map.GetNode("p");
EXPECT_EQ(pad->input(0), "Conv2D");
- auto pad_const = node_map.GetNode("p-c-LayoutOptimizer");
+ auto pad_const = node_map.GetNode("p-1-LayoutOptimizer");
EXPECT_TRUE(pad_const);
EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
Tensor tensor;
@@ -476,9 +475,9 @@ TEST_F(LayoutOptimizerTest, SplitDimC) {
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer");
+ EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
EXPECT_EQ(split_node->input(1), "Conv2D");
- auto split_const = node_map.GetNode("split-c-LayoutOptimizer");
+ auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
EXPECT_EQ(split_const->op(), "Const");
EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
}
@@ -496,9 +495,9 @@ TEST_F(LayoutOptimizerTest, SplitDimH) {
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer");
+ EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
EXPECT_EQ(split_node->input(1), "Conv2D");
- auto split_const = node_map.GetNode("split-c-LayoutOptimizer");
+ auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
EXPECT_EQ(split_const->op(), "Const");
EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
}
@@ -516,9 +515,9 @@ TEST_F(LayoutOptimizerTest, SplitDimW) {
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer");
+ EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
EXPECT_EQ(split_node->input(1), "Conv2D");
- auto split_const = node_map.GetNode("split-c-LayoutOptimizer");
+ auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
EXPECT_EQ(split_const->op(), "Const");
EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
}
@@ -536,9 +535,9 @@ TEST_F(LayoutOptimizerTest, SplitDimN) {
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto split_node = node_map.GetNode("split");
- EXPECT_EQ(split_node->input(0), "split-c-LayoutOptimizer");
+ EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
EXPECT_EQ(split_node->input(1), "Conv2D");
- auto split_const = node_map.GetNode("split-c-LayoutOptimizer");
+ auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
EXPECT_EQ(split_const->op(), "Const");
EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
}
@@ -582,8 +581,8 @@ TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
EXPECT_EQ(concat_node->input(0), "split:1");
EXPECT_EQ(concat_node->input(1), "split:1");
EXPECT_EQ(concat_node->input(2), "split:1");
- EXPECT_EQ(concat_node->input(3), "concat-axis-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
+ EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer");
+ auto concat_dim = node_map.GetNode("concat-3-LayoutOptimizer");
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}
@@ -603,8 +602,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimH) {
auto concat_node = node_map.GetNode("concat");
EXPECT_EQ(concat_node->input(0), "split");
EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
+ EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
+ auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
}
@@ -648,8 +647,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimW) {
auto concat_node = node_map.GetNode("concat");
EXPECT_EQ(concat_node->input(0), "split");
EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
+ EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
+ auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
}
@@ -669,8 +668,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimN) {
auto concat_node = node_map.GetNode("concat");
EXPECT_EQ(concat_node->input(0), "split");
EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
+ EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
+ auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
}
@@ -690,8 +689,8 @@ TEST_F(LayoutOptimizerTest, ConcatDimC) {
auto concat_node = node_map.GetNode("concat");
EXPECT_EQ(concat_node->input(0), "split");
EXPECT_EQ(concat_node->input(1), "split:1");
- EXPECT_EQ(concat_node->input(2), "concat-axis-LayoutOptimizer");
- auto concat_dim = node_map.GetNode("concat-axis-LayoutOptimizer");
+ EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
+ auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}
@@ -861,10 +860,10 @@ TEST_F(LayoutOptimizerTest, SliceConst) {
NodeMap node_map(&output);
auto slice_node = node_map.GetNode("slice");
EXPECT_EQ(slice_node->input(0), "Conv2D");
- EXPECT_EQ(slice_node->input(1), "slice-begin-LayoutOptimizer");
- EXPECT_EQ(slice_node->input(2), "slice-size-LayoutOptimizer");
+ EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer");
+ EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer");
- auto begin_const = node_map.GetNode("slice-begin-LayoutOptimizer");
+ auto begin_const = node_map.GetNode("slice-1-LayoutOptimizer");
Tensor begin_tensor;
EXPECT_TRUE(begin_tensor.FromProto(
begin_const->mutable_attr()->at({"value"}).tensor()));
@@ -872,7 +871,7 @@ TEST_F(LayoutOptimizerTest, SliceConst) {
test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);
- auto size_const = node_map.GetNode("slice-size-LayoutOptimizer");
+ auto size_const = node_map.GetNode("slice-2-LayoutOptimizer");
Tensor size_tensor;
EXPECT_TRUE(size_tensor.FromProto(
size_const->mutable_attr()->at({"value"}).tensor()));
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 72791cbf6f..f537ecc41b 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -775,8 +775,10 @@ static const NodeDef* FindSwapInTrigger(
return nullptr;
}
-static bool IsSwappable(GraphView::OutputPort output) {
+static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
const NodeDef& node = *output.node;
+ // There is no point in swapping out persistent tensors, since the tensor will
+ // continue to use memory.
if (IsPersistent(node)) {
return false;
}
@@ -785,13 +787,29 @@ static bool IsSwappable(GraphView::OutputPort output) {
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
return false;
}
-
DataType dtype;
if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
return false;
}
+ // References can only refer to persistent memory: therefore the node isn't
+ // swappable.
+ if (IsRefType(dtype)) {
+ return false;
+ }
- return !IsRefType(dtype);
+ if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
+ // If placed on the same device, these nodes are just forwarding references
+ // to their input. Therefore they are swappable iff their fanin is swappable
+ // or it resides on a different device.
+ GraphView::InputPort input;
+ input.node = output.node;
+ input.port_id = 0;
+ GraphView::OutputPort fanin = graph.GetRegularFanin(input);
+ if (fanin.node->device() == node.device()) {
+ return IsSwappable(graph, fanin);
+ }
+ }
+ return true;
}
static NodeDef* FindSwapOutTrigger(
@@ -811,7 +829,7 @@ static NodeDef* FindSwapOutTrigger(
view.GetFanout(generator);
NodeDef* trigger = nullptr;
Costs::NanoSeconds earliest_fanout(
- static_cast<double>(std::numeric_limits<int>::max()));
+ static_cast<double>(std::numeric_limits<int64>::max() >> 2));
for (const auto& port : fanout) {
if (port.node == node) {
@@ -843,8 +861,9 @@ static bool IsSwappable(GraphView::InputPort input) {
return !IsRefType(dtype);
}
-static bool IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item,
- std::unordered_set<string>* skip_list) {
+static bool IdentifySwappingCandidates(
+ Cluster* cluster, GrapplerItem* item, std::unordered_set<string>* skip_list,
+ std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
GraphMemory memory(*item);
const std::unordered_map<string, DeviceProperties>& devices =
cluster->GetDevices();
@@ -898,10 +917,9 @@ static bool IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item,
// Don't bother with small tensors.
continue;
}
- // Don't try to swap out persistent data
GraphView::OutputPort port =
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
- if (!IsSwappable(port)) {
+ if (!IsSwappable(graph, port)) {
continue;
}
Costs::NanoSeconds execution_time(-1);
@@ -943,9 +961,8 @@ static bool IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item,
}
}
if (!found) {
- AttrValue& val =
- (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
- val.mutable_list()->add_i(fanout_to_swap.port_id);
+ (*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
+ fanout_to_swap.port_id);
required_savings -= live_tensor.memory_used;
updated_graph = true;
if (required_savings < 0) {
@@ -961,14 +978,13 @@ static bool IdentifySwappingCandidates(Cluster* cluster, GrapplerItem* item,
bool SwappingPass(RewriterConfig::MemOptType optimization_level,
Cluster* cluster, GrapplerItem* item,
std::unordered_set<string>* skip_list) {
- bool updated_graph = false;
+ std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
if (optimization_level == RewriterConfig::SWAPPING_HEURISTICS ||
optimization_level == RewriterConfig::HEURISTICS) {
// Use heuristics to figure out what needs to be swapped;
- updated_graph = IdentifySwappingCandidates(cluster, item, skip_list);
+ IdentifySwappingCandidates(cluster, item, skip_list, &nodes_to_swap);
}
// Look for manual annotatations in the graph.
- std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
for (auto& node : *item->graph.mutable_node()) {
if (node.attr().count("_swap_to_host") != 0) {
SwapInfo& swap_info = nodes_to_swap[&node];
@@ -1018,10 +1034,11 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
}
GraphView view(&item->graph);
+ bool updated_graph = false;
+
for (auto& swap : nodes_to_swap) {
NodeDef* node = swap.first;
const SwapInfo& swap_info = swap.second;
-
if (skip_list->find(node->name()) != skip_list->end()) {
continue;
}
@@ -1047,7 +1064,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
skip_list->insert(input_name);
}
- // Make sure the tensor isn't swapped out quickly look for node that
+ // Make sure the tensor is swapped out quickly: look for node that
// will execute just after the tensor is generated and add a control
// dependency from the swap out node to that node.
NodeDef* out_trigger =
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index f507178bce..dd2d20d8d6 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -271,18 +271,23 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+ Output v = ops::Variable(s.WithOpName("v").WithDevice("/gpu:0"),
{128, 128, 8}, DT_FLOAT);
- Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
- Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
- Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
+ Output a = ops::Identity(s.WithOpName("a").WithDevice("/gpu:0"), v);
+ Output b = ops::Square(s.WithOpName("b").WithDevice("/gpu:0"), v);
+ Output c = ops::Sqrt(s.WithOpName("c").WithDevice("/gpu:0"), a);
+ Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), b);
Output axis = ops::Const(s.WithOpName("axis"), 0);
Output e =
- ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+ ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {a, b, c, d}, axis);
+ Output f = ops::Square(s.WithOpName("f").WithDevice("/gpu:0"), a);
+ Output g = ops::Sqrt(s.WithOpName("g").WithDevice("/gpu:0"), b);
+ Output h = ops::Exp(s.WithOpName("h").WithDevice("/gpu:0"), c);
+ Output i = ops::Log(s.WithOpName("i").WithDevice("/gpu:0"), d);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch = {"e"};
+ item.fetch = {"e", "f", "g", "h", "i"};
std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
@@ -293,28 +298,27 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
for (const auto& node : output.node()) {
if (node.name() == "e") {
- EXPECT_TRUE(node.attr().count("_swap_to_host") > 0);
- const AttrValue& val = node.attr().at("_swap_to_host");
- EXPECT_TRUE(val.has_list());
- std::set<int> inputs_to_swap;
- for (int64 input_id : val.list().i()) {
- inputs_to_swap.insert(input_id);
- }
- EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
+ EXPECT_EQ(5, node.input_size());
+ EXPECT_EQ("a", node.input(0));
+ EXPECT_EQ("swap_in_e_1", node.input(1));
+ EXPECT_EQ("swap_in_e_2", node.input(2));
+ EXPECT_EQ("swap_in_e_3", node.input(3));
+ EXPECT_EQ("axis", node.input(4));
}
}
}
TEST_F(MemoryOptimizerTest, UnswappableInputs) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+ Output v = ops::Variable(s.WithOpName("v").WithDevice("/gpu:0"),
{128, 128, 8}, DT_FLOAT);
+ Output a = ops::Square(s.WithOpName("a").WithDevice("/gpu:0"), v);
Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
Output index = ops::Const(s.WithOpName("index"), {0});
Output indices = ops::Tile(s.WithOpName("indices"), index, {128});
Output d =
- ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), a, indices, c);
+ ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), v, indices, c);
Output axis = ops::Const(s.WithOpName("axis"), 0);
Output e =
ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
@@ -331,9 +335,10 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) {
TF_EXPECT_OK(status);
for (const auto& node : output.node()) {
- if (node.name() == "d") {
- EXPECT_EQ(1, node.attr().count("_swap_to_host"));
- EXPECT_EQ(2, node.attr().at("_swap_to_host").list().i(0));
+ if (node.name() == "e") {
+ // The d node isn't swappable.
+ EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ("d", node.input(2));
}
}
}
diff --git a/tensorflow/core/grappler/optimizers/static_schedule.h b/tensorflow/core/grappler/optimizers/static_schedule.h
index aa2726a2bd..678b4d193f 100644
--- a/tensorflow/core/grappler/optimizers/static_schedule.h
+++ b/tensorflow/core/grappler/optimizers/static_schedule.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_
#include <unordered_map>
@@ -47,4 +47,4 @@ Status EstimateRequiredTimes(
} // namespace grappler
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_
diff --git a/tensorflow/core/grappler/utils/frame.h b/tensorflow/core/grappler/utils/frame.h
index be726ae795..95b72748f4 100644
--- a/tensorflow/core/grappler/utils/frame.h
+++ b/tensorflow/core/grappler/utils/frame.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
#include <unordered_map>
#include "tensorflow/core/framework/graph.pb.h"
@@ -40,4 +40,4 @@ Status IdentifyFramesWithNodeMap(const GraphDef& graph, const NodeMap& node_map,
} // namespace grappler
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h
index 4e46169971..4fb7aab647 100644
--- a/tensorflow/core/grappler/utils/scc.h
+++ b/tensorflow/core/grappler/utils/scc.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_
#include <unordered_map>
#include "tensorflow/core/framework/graph.pb.h"
@@ -43,4 +43,4 @@ int IdentifyLoops(const GraphDef& graph,
} // namespace grappler
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_
diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h
index f2c9bbfa4e..7700fe41e4 100644
--- a/tensorflow/core/grappler/utils/topological_sort.h
+++ b/tensorflow/core/grappler/utils/topological_sort.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -28,4 +28,4 @@ Status TopologicalSort(GraphDef* graph);
} // namespace grappler
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f40074f1af..fd99409c9b 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -369,6 +369,22 @@ cc_library(
],
)
+cc_library(
+ name = "batch_kernels",
+ srcs = ["batch_kernels.cc"],
+ deps = [
+ "//tensorflow/core:batch_ops_op_lib",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/kernels:concat_lib_hdrs",
+ "//tensorflow/core/kernels:ops_util_hdrs",
+ "//tensorflow/core/kernels:split_lib_hdrs",
+ "//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
+ "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs",
+ ],
+ alwayslink = 1,
+)
+
tf_kernel_library(
name = "record_input_op",
srcs = [
@@ -4268,7 +4284,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/fft2d:fft2d_headers",
- "@fft2d//:fft2d",
+ "@fft2d",
],
)
@@ -4975,6 +4991,7 @@ filegroup(
"debug_ops.*",
"scatter_nd_op*",
"critical_section.*",
+ "batch_kernels.*",
],
),
visibility = ["//visibility:public"],
@@ -5007,8 +5024,8 @@ cc_library(
"//tensorflow/core:protos_all_cc_impl",
"//third_party/eigen3",
"//third_party/fft2d:fft2d_headers",
- "@fft2d//:fft2d",
- "@gemmlowp//:gemmlowp",
+ "@fft2d",
+ "@gemmlowp",
"@protobuf_archive//:protobuf",
],
alwayslink = 1,
@@ -5079,7 +5096,7 @@ tf_kernel_library(
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
- "@gemmlowp//:gemmlowp",
+ "@gemmlowp",
],
)
@@ -5840,9 +5857,10 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
+ ] + if_mkl([
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn//:mkl_dnn",
- ],
+ ]),
)
tf_mkl_kernel_library(
@@ -6008,6 +6026,6 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
- "@gemmlowp//:gemmlowp",
+ "@gemmlowp",
],
)
diff --git a/tensorflow/core/kernels/adjust_hsv_gpu.cu.h b/tensorflow/core/kernels/adjust_hsv_gpu.cu.h
index c160ce2c33..49df5ae296 100644
--- a/tensorflow/core/kernels/adjust_hsv_gpu.cu.h
+++ b/tensorflow/core/kernels/adjust_hsv_gpu.cu.h
@@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
#if GOOGLE_CUDA
@@ -143,4 +143,4 @@ __global__ void adjust_hsv_nhwc(const int64 number_elements,
} // namespace tensorflow
#endif // GOOGLE_CUDA
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_
diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index 6041d8c9b2..5b4e1a809f 100644
--- a/tensorflow/contrib/batching/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -13,20 +13,22 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/batching/shared_batch_scheduler.h"
-#include "tensorflow/contrib/batching/util/periodic_function.h"
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
+#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/macros.h"
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
diff --git a/tensorflow/core/kernels/batch_util.h b/tensorflow/core/kernels/batch_util.h
index b066e2a574..0d634ae7b0 100644
--- a/tensorflow/core/kernels/batch_util.h
+++ b/tensorflow/core/kernels/batch_util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -35,4 +35,4 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index);
} // namespace batch_util
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_
diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
index ff8ebb349f..25c5f9cf42 100644
--- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
-
+#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#include <algorithm>
#include <functional>
@@ -657,4 +656,4 @@ size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
diff --git a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h
index 9207972100..2b5a991caf 100644
--- a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_
#include <stddef.h>
#include <cstddef>
@@ -265,4 +265,4 @@ BasicBatchScheduler<TaskType>::BasicBatchScheduler(
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_
diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler.h b/tensorflow/core/kernels/batching_util/batch_scheduler.h
index a5316f152b..f6d9a8f0c8 100644
--- a/tensorflow/core/kernels/batching_util/batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/batch_scheduler.h
@@ -23,8 +23,8 @@ limitations under the License.
//
// This file defines an abstract BatchScheduler class.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
#include <stddef.h>
#include <algorithm>
@@ -278,4 +278,4 @@ void Batch<TaskType>::Close() {
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
diff --git a/tensorflow/core/kernels/batching_util/fake_clock_env.h b/tensorflow/core/kernels/batching_util/fake_clock_env.h
index b2848afe07..60f1cbe7bd 100644
--- a/tensorflow/core/kernels/batching_util/fake_clock_env.h
+++ b/tensorflow/core/kernels/batching_util/fake_clock_env.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_
+#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_
#include <functional>
#include <string>
@@ -73,4 +73,4 @@ class FakeClockEnv : public EnvWrapper {
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_
diff --git a/tensorflow/core/kernels/batching_util/periodic_function.h b/tensorflow/core/kernels/batching_util/periodic_function.h
index 6811cd015e..dbf1733dcc 100644
--- a/tensorflow/core/kernels/batching_util/periodic_function.h
+++ b/tensorflow/core/kernels/batching_util/periodic_function.h
@@ -49,9 +49,8 @@ limitations under the License.
// PeriodicFunction periodic_function_;
// };
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_
-
+#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_
+#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
@@ -132,4 +131,4 @@ class PeriodicFunction {
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_
diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
index 3736d8ef64..b77289aded 100644
--- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_
#include <stddef.h>
#include <deque>
@@ -702,4 +702,4 @@ size_t QueueHandle<TaskType>::SchedulingCapacity() const {
} // namespace serving
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index 42f3db1d79..2ca194a77f 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -173,19 +173,13 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
// Accumulate the results in the shared memory into the first element.
// No syncthreads is needed since this is only in the same warp.
int32 thread_index = threadIdx.x;
- if (thread_index < 16) {
- s_data[thread_index] += s_data[thread_index + 16];
- __syncwarp(0xFFFF);
- if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
- __syncwarp(0xFF);
- if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
- __syncwarp(0xF);
- if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
- __syncwarp(0x3);
+ if (thread_index < 32) {
+ AccT data = s_data[thread_index];
+ for (int32 delta = warpSize / 2; delta > 0; delta /= 2) {
+ data += CudaShuffleXorSync(kCudaWarpAll, data, delta);
+ }
if (thread_index == 0) {
- T val = T(s_data[0] + s_data[1]);
- // The first thread writes out the accumulated result to global location.
- CudaAtomicAdd(bias_backprop + bias_index, val);
+ CudaAtomicAdd(bias_backprop + bias_index, T(data));
}
}
}
diff --git a/tensorflow/core/kernels/bitcast_op.h b/tensorflow/core/kernels/bitcast_op.h
index 0413569e79..900ab6f35c 100644
--- a/tensorflow/core/kernels/bitcast_op.h
+++ b/tensorflow/core/kernels/bitcast_op.h
@@ -15,8 +15,8 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_
#include <string.h> // for memcpy
@@ -27,4 +27,4 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/casts.h"
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_
diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/core/kernels/captured_function.h
index cdf191f4c7..2d2d87134e 100644
--- a/tensorflow/core/kernels/captured_function.h
+++ b/tensorflow/core/kernels/captured_function.h
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
+#define TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
#include "tensorflow/core/kernels/data/captured_function.h"
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
+#endif // TENSORFLOW_CORE_KERNELS_CAPTURED_FUNCTION_H_
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
index 6309e4a4dc..470e9e0804 100644
--- a/tensorflow/core/kernels/cast_op_impl.h
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
#define EIGEN_USE_THREADS
@@ -181,4 +181,4 @@ GetSyclCastFromDouble(DataType dst_dtype);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/compare_and_bitpack_op.h b/tensorflow/core/kernels/compare_and_bitpack_op.h
index 8e020249c1..af8566c7ce 100644
--- a/tensorflow/core/kernels/compare_and_bitpack_op.h
+++ b/tensorflow/core/kernels/compare_and_bitpack_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
+#define TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -39,4 +39,4 @@ struct CompareAndBitpack {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_COMPARE_AND_BITPACK_OP_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index 27db6ee785..794ac6fa6d 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -161,21 +161,21 @@ class ConditionalAccumulatorBase : public ResourceBase {
* The below macros return a boolean if the test fails, so that the calling
* function can get an indication that a failure has occurred.
*/
-#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS) \
- do { \
- if (!TF_PREDICT_TRUE(EXP)) { \
- (CTX)->CtxFailure((STATUS)); \
- return false; \
- } \
+#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) { \
+ (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
+ return false; \
+ } \
} while (0)
-#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS) \
- do { \
- ::tensorflow::Status _s(STATUS); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(_s); \
- return false; \
- } \
+#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS) \
+ do { \
+ ::tensorflow::Status _s(STATUS); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return false; \
+ } \
} while (0)
/*
diff --git a/tensorflow/core/kernels/cuda_device_array.h b/tensorflow/core/kernels/cuda_device_array.h
index a570993cf8..e7a5db0683 100644
--- a/tensorflow/core/kernels/cuda_device_array.h
+++ b/tensorflow/core/kernels/cuda_device_array.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
+#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
#if GOOGLE_CUDA
@@ -117,4 +117,4 @@ class CudaDeviceArrayOnHost {
#endif // GOOGLE_CUDA
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
+#endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
diff --git a/tensorflow/core/kernels/cuda_device_array_gpu.h b/tensorflow/core/kernels/cuda_device_array_gpu.h
index 220f762636..64fa3cb806 100644
--- a/tensorflow/core/kernels/cuda_device_array_gpu.h
+++ b/tensorflow/core/kernels/cuda_device_array_gpu.h
@@ -15,8 +15,8 @@ limitations under the License.
// Contains structs and functions to be included in device code.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
#if GOOGLE_CUDA
@@ -47,4 +47,4 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetCudaDeviceArrayOnDevice(
#endif // GOOGLE_CUDA
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 3c389a82ab..ecfa23750c 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -427,7 +427,7 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
int64 size, const string& debug_info) {
DeviceLapackInfo new_dev_info(context_, size, debug_info);
scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
- return std::move(new_dev_info);
+ return new_dev_info;
}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc
index 5fb0735ac1..cf86478b0f 100644
--- a/tensorflow/core/kernels/cwise_op_pow.cc
+++ b/tensorflow/core/kernels/cwise_op_pow.cc
@@ -16,8 +16,9 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER7(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, int32,
- int64, complex64, complex128);
+REGISTER5(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double,
+ complex64, complex128);
+REGISTER2(BinaryOp, CPU, "Pow", functor::safe_pow, int32, int64);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double,
@@ -25,5 +26,5 @@ REGISTER4(BinaryOp, GPU, "Pow", functor::pow, float, Eigen::half, double,
#endif
#ifdef TENSORFLOW_USE_SYCL
REGISTER2(BinaryOp, SYCL, "Pow", functor::pow, float, double);
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index da70b1e314..06918075a4 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <type_traits>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -115,6 +116,35 @@ struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
};
+template <typename Scalar, typename Exponent>
+struct safe_scalar_binary_pow_op {
+ static_assert(std::is_integral<Scalar>::value, "Integer type expected");
+ static_assert(std::is_integral<Exponent>::value &&
+ std::is_signed<Exponent>::value,
+ "Signed integer type expected");
+
+ bool* const error;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
+ : error(error) {}
+
+ EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
+ const Exponent& b) const {
+ const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
+ if (TF_PREDICT_TRUE(safe_b >= 0)) {
+ return numext::pow(a, safe_b);
+ } else {
+ *error = true;
+ return 0;
+ }
+ }
+};
+
+template <typename Scalar, typename Exponent>
+struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
+ enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+
template <typename T, typename DivOrMod>
struct safe_div_or_mod_op {
static_assert(std::is_integral<T>::value, "Integer type expected");
@@ -742,6 +772,11 @@ template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};
template <typename T>
+struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
+ static const bool has_errors = true;
+};
+
+template <typename T>
struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {};
template <typename T>
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 693c6467ac..e561e59cf5 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -40,6 +40,11 @@ void BinaryOpShared::SetComputeError(OpKernelContext* ctx) {
if ((op == "Div" || op == "Mod" || op == "FloorMod" || op == "FloorDiv") &&
DataTypeIsInteger(ctx->op_kernel().input_type(0))) {
ctx->CtxFailure(errors::InvalidArgument("Integer division by zero"));
+ } else if ((op == "Pow") &&
+ DataTypeIsInteger(ctx->op_kernel().input_type(0)) &&
+ DataTypeIsSigned(ctx->op_kernel().input_type(1))) {
+ ctx->CtxFailure(errors::InvalidArgument(
+ "Integers to negative integer powers are not allowed"));
} else {
ctx->CtxFailure(
errors::Internal("Unexpected error in binary operator "
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index 99e0ef426e..32d2bc3aae 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
#include <memory>
#include <vector>
@@ -105,4 +105,4 @@ class CapturedFunction {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
diff --git a/tensorflow/core/kernels/data/dataset.h b/tensorflow/core/kernels/data/dataset.h
index 3cb3c08a32..2ef31ddfaa 100644
--- a/tensorflow/core/kernels/data/dataset.h
+++ b/tensorflow/core/kernels/data/dataset.h
@@ -12,18 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
#include <memory>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -193,9 +195,17 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
+ // Returns whether an op has been whitelisted for use inside map_fns.
+ // Uses a heuristic to whitelist source dataset ops which have been
+ // marked stateful due to b/65524810.
+ // Also looks up the `op_def->name` in the global
+ // `WhitelistedStatefulOpRegistry`.
bool IsOpWhitelisted(const OpDef* op_def) const {
- return StringPiece(op_def->name()).ends_with("Dataset") &&
- HasAttr(op_def, "output_shapes");
+ return (StringPiece(op_def->name()).ends_with("Dataset") &&
+ op_def->output_arg_size() == 1 &&
+ op_def->output_arg(0).type() == DT_VARIANT) ||
+ dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
+ op_def->name());
}
bool HasAttr(const string& op_type_name, const string& attr_name) const;
@@ -596,4 +606,4 @@ Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_H_
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 40bc873584..6c4191c2be 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
@@ -32,4 +32,4 @@ Status MakeIteratorFromInputElement(
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 244df137cb..56044a3d41 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -829,8 +829,8 @@ class IteratorGetNextOp : public AsyncOpKernel {
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
IteratorResource* iterator;
- OP_REQUIRES_OK(ctx,
- LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
@@ -870,6 +870,39 @@ class IteratorGetNextOp : public AsyncOpKernel {
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
+class IteratorGetNextSyncOp : public OpKernel {
+ public:
+ explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ IteratorResource* iterator;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
+ core::ScopedUnref unref_iterator(iterator);
+
+ std::vector<Tensor> components;
+ bool end_of_sequence = false;
+
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.stats_aggregator_getter = [iterator]() {
+ return iterator->stats_aggregator();
+ };
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ IteratorContext iter_ctx(std::move(params));
+
+ OP_REQUIRES_OK(ctx,
+ iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
+ OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
+
+ for (int i = 0; i < components.size(); ++i) {
+ // TODO(mrry): Check that the shapes match the shape attrs.
+ ctx->set_output(i, components[i]);
+ }
+ }
+};
+
class IteratorToStringHandleOp : public OpKernel {
public:
explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
@@ -1033,6 +1066,8 @@ REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
IteratorGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
+ IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 01f9b9fa09..89360d1cd9 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -95,10 +95,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
- DataTypeVector other_arguments_types(
- captured_func_->captured_inputs().size());
- std::vector<Node*> other_arguments(
- captured_func_->captured_inputs().size());
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index f09871d98d..bc4426a9fd 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -109,10 +109,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
// Input: other_arguments
- DataTypeVector other_arguments_types(
- captured_func_->captured_inputs().size());
- std::vector<Node*> other_arguments(
- captured_func_->captured_inputs().size());
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h
index 0d0c38eb58..a34691b5a2 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.h
+++ b/tensorflow/core/kernels/data/sql/driver_manager.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
#include "tensorflow/core/kernels/data/sql/query_connection.h"
@@ -38,4 +38,4 @@ class DriverManager {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h
index 1947148972..f31017bd19 100644
--- a/tensorflow/core/kernels/data/sql/query_connection.h
+++ b/tensorflow/core/kernels/data/sql/query_connection.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
#include "tensorflow/core/framework/tensor.h"
@@ -64,4 +64,4 @@ class QueryConnection {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
index b36b69eae4..787c17d6c0 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
#include <memory>
@@ -53,4 +53,4 @@ class SqliteQueryConnection : public QueryConnection {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/stats_aggregator.h b/tensorflow/core/kernels/data/stats_aggregator.h
index 4cb8dba5cb..076a56b0bf 100644
--- a/tensorflow/core/kernels/data/stats_aggregator.h
+++ b/tensorflow/core/kernels/data/stats_aggregator.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
#include <memory>
#include <string>
@@ -81,4 +81,4 @@ class StatsAggregatorResource : public ResourceBase {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_STATS_AGGREGATOR_H_
diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h
index 25396bd3e7..97c31668ac 100644
--- a/tensorflow/core/kernels/data/window_dataset.h
+++ b/tensorflow/core/kernels/data/window_dataset.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
#include <vector>
@@ -45,4 +45,4 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index 2aa6dbe6f3..69ab78d635 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATASET_H_
#include "tensorflow/core/kernels/data/dataset.h"
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATASET_H_
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index c778278e8f..b7d120a617 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -39,6 +39,13 @@ class DecodeBmpOp : public OpKernel {
errors::InvalidArgument("channels must be 0, 1, 3 or 4, got ",
channels_));
}
+ inline int32 ByteSwapInt32ForBigEndian(int32 x) {
+#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+ return le32toh(x);
+#else
+ return x;
+#endif
+ }
void Compute(OpKernelContext* context) override {
const Tensor& contents = context->input(0);
@@ -56,14 +63,18 @@ class DecodeBmpOp : public OpKernel {
input.size(), " bytes"));
const uint8* img_bytes = reinterpret_cast<const uint8*>(input.data());
- const int32 header_size = internal::SubtleMustCopy(
+ int32 header_size_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 10)));
- const int32 width = internal::SubtleMustCopy(
+ const int32 header_size = ByteSwapInt32ForBigEndian(header_size_);
+ int32 width_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 18)));
- const int32 height = internal::SubtleMustCopy(
+ const int32 width = ByteSwapInt32ForBigEndian(width_);
+ int32 height_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 22)));
- const int32 bpp = internal::SubtleMustCopy(
+ const int32 height = ByteSwapInt32ForBigEndian(height_);
+ int32 bpp_ = internal::SubtleMustCopy(
*(reinterpret_cast<const int32*>(img_bytes + 28)));
+ const int32 bpp = ByteSwapInt32ForBigEndian(bpp_);
if (channels_) {
OP_REQUIRES(context, (channels_ == bpp / 8),
diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc
index ceb152c3f0..44dcbf834c 100644
--- a/tensorflow/core/kernels/decode_image_op.cc
+++ b/tensorflow/core/kernels/decode_image_op.cc
@@ -87,11 +87,10 @@ class DecodeImageOp : public OpKernel {
channels_ = 3;
} else {
OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
- OP_REQUIRES(
- context,
- channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
- errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ",
- channels_));
+ OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3 ||
+ channels_ == 4,
+ errors::InvalidArgument(
+ "channels must be 0, 1, 3, or 4, got ", channels_));
}
flags_.components = channels_;
@@ -115,9 +114,8 @@ class DecodeImageOp : public OpKernel {
if (format_ == kJpgFormat) {
OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio));
- OP_REQUIRES(context,
- flags_.ratio == 1 || flags_.ratio == 2 || flags_.ratio == 4 ||
- flags_.ratio == 8,
+ OP_REQUIRES(context, flags_.ratio == 1 || flags_.ratio == 2 ||
+ flags_.ratio == 4 || flags_.ratio == 8,
errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ",
flags_.ratio));
OP_REQUIRES_OK(context, context->GetAttr("fancy_upscaling",
@@ -132,9 +130,8 @@ class DecodeImageOp : public OpKernel {
string dct_method;
OP_REQUIRES_OK(context, context->GetAttr("dct_method", &dct_method));
OP_REQUIRES(
- context,
- (dct_method.empty() || dct_method == "INTEGER_FAST" ||
- dct_method == "INTEGER_ACCURATE"),
+ context, (dct_method.empty() || dct_method == "INTEGER_FAST" ||
+ dct_method == "INTEGER_ACCURATE"),
errors::InvalidArgument("dct_method must be one of "
"{'', 'INTEGER_FAST', 'INTEGER_ACCURATE'}"));
if (dct_method == "INTEGER_FAST") {
@@ -160,9 +157,9 @@ class DecodeImageOp : public OpKernel {
errors::InvalidArgument("Expected image (JPEG, PNG, or GIF), got ",
FileFormatString(magic, input)));
OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
- errors::InvalidArgument(
- FileFormatString(magic, input),
- " contents are too large for int: ", input.size()));
+ errors::InvalidArgument(FileFormatString(magic, input),
+ " contents are too large for int: ",
+ input.size()));
OP_REQUIRES(context, magic == kPngFormat || channel_bits_ == 8,
errors::InvalidArgument(FileFormatString(magic, input),
" does not support uint16 output"));
@@ -215,10 +212,9 @@ class DecodeImageOp : public OpKernel {
input.data(), input.size(), flags, nullptr /* nwarn */,
[=, &output](int width, int height, int channels) -> uint8* {
Status status(context->allocate_output(
- 0,
- format_ == kGifFormat
- ? TensorShape({1, height, width, channels})
- : TensorShape({height, width, channels}),
+ 0, format_ == kGifFormat
+ ? TensorShape({1, height, width, channels})
+ : TensorShape({height, width, channels}),
&output));
if (!status.ok()) {
VLOG(1) << status;
@@ -294,6 +290,7 @@ class DecodeImageOp : public OpKernel {
// Decode GIF, allocating tensor once the size is known.
Tensor* output = nullptr;
+ string error_string;
OP_REQUIRES(
context,
gif::Decode(input.data(), input.size(),
@@ -320,8 +317,10 @@ class DecodeImageOp : public OpKernel {
return nullptr;
}
return output->flat<uint8>().data();
- }),
- errors::InvalidArgument("Invalid GIF data, size ", input.size()));
+ },
+ &error_string),
+ errors::InvalidArgument("Invalid GIF data (size ", input.size(), "), ",
+ error_string));
}
private:
diff --git a/tensorflow/core/kernels/deep_conv2d.h b/tensorflow/core/kernels/deep_conv2d.h
index c3f6f66dc9..17a0230516 100644
--- a/tensorflow/core/kernels/deep_conv2d.h
+++ b/tensorflow/core/kernels/deep_conv2d.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
+#define TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
#include "tensorflow/core/framework/types.h"
@@ -114,4 +114,4 @@ struct DeepConv2D {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
+#endif // TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_
diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h
index 097a9f5bfa..ba262d56ee 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.h
+++ b/tensorflow/core/kernels/depthwise_conv_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/types.h"
@@ -284,4 +284,4 @@ struct DepthwiseInputCopyOp {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 903aac5d68..5493e33532 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -34,6 +34,7 @@ limitations under the License.
namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
using Eigen::GpuDevice;
// Returns whether depthwise convolution forward or backward input pass can be
@@ -1028,7 +1029,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
int zeros = sub_warp * kWidth;
unsigned mask = ((1UL << kWidth) - 1) << zeros;
for (int delta = kWidth / 2; delta > 0; delta /= 2) {
- val += CudaShuffleXor(mask, val, delta);
+ val += CudaShuffleXorSync(mask, val, delta);
}
return val;
}
@@ -1145,7 +1146,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
- unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);
+ unsigned active_threads = CudaBallotSync(kCudaWarpAll, depth_in_range);
if (depth_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1159,7 +1160,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
- val += CudaShuffleDown(active_threads, val, delta);
+ val += CudaShuffleXorSync(active_threads, val, delta);
}
if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
*accum_ptr = val;
@@ -1399,7 +1400,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
- unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);
+ unsigned active_threads = CudaBallotSync(kCudaWarpAll, slice_in_range);
if (slice_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1413,10 +1414,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
- val += CudaShuffleDown(active_threads, val, delta);
+ val += CudaShuffleXorSync(active_threads, val, delta);
}
if (!(thread_idx & 32 / kBlockSlices - 1)) {
- *accum_ptr = val;
+ *accum_ptr = val; // kBlockSlices threads per warp.
}
++shared_offset;
accum_ptr += accum_increment;
diff --git a/tensorflow/core/kernels/determinant_op.h b/tensorflow/core/kernels/determinant_op.h
index e931e328e4..eefdfe0ae4 100644
--- a/tensorflow/core/kernels/determinant_op.h
+++ b/tensorflow/core/kernels/determinant_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
#include "tensorflow/core/framework/tensor_types.h"
@@ -44,4 +44,4 @@ struct LogDeterminantFromPivotedLUFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
diff --git a/tensorflow/core/kernels/eigen_activations.h b/tensorflow/core/kernels/eigen_activations.h
index 57c8157b87..99b4b2abe6 100644
--- a/tensorflow/core/kernels/eigen_activations.h
+++ b/tensorflow/core/kernels/eigen_activations.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -122,4 +122,4 @@ struct functor_traits<scalar_clip_op<Scalar> > {
} // end namespace Eigen
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_
diff --git a/tensorflow/core/kernels/eigen_attention.h b/tensorflow/core/kernels/eigen_attention.h
index f4c42372b1..3a94b8c993 100644
--- a/tensorflow/core/kernels/eigen_attention.h
+++ b/tensorflow/core/kernels/eigen_attention.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -239,4 +239,4 @@ ExtractGlimpses(const Input& input,
} // end namespace Eigen
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index a44e7197a9..e13e548f86 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_volume_patch.h"
@@ -617,4 +617,4 @@ CuboidConvolutionBackwardKernel(
} // end namespace Eigen
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index d172de8e18..aec7697810 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index 2dca664a86..62e9f9123d 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_volume_patch.h"
@@ -224,4 +224,4 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
} // end namespace Eigen
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h
index 94100d71ec..972036833f 100644
--- a/tensorflow/core/kernels/eigen_pooling.h
+++ b/tensorflow/core/kernels/eigen_pooling.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_volume_patch.h"
@@ -610,4 +610,4 @@ CuboidAvgPooling(const Input& input, DenseIndex patchPlanes,
} // end namespace Eigen
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_
diff --git a/tensorflow/core/kernels/eigen_softmax.h b/tensorflow/core/kernels/eigen_softmax.h
index 20bb8a44dd..a2930a726f 100644
--- a/tensorflow/core/kernels/eigen_softmax.h
+++ b/tensorflow/core/kernels/eigen_softmax.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -87,4 +87,4 @@ SoftMax(const Input& input, const float beta)
} // end namespace Eigen
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_SOFTMAX_H_
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index 7702f3e70a..2fe64cd72a 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -1069,4 +1069,4 @@ EIGEN_DEVICE_FUNC
} // end namespace Eigen
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h
index afd5f37e35..a3d795813d 100644
--- a/tensorflow/core/kernels/eigen_volume_patch.h
+++ b/tensorflow/core/kernels/eigen_volume_patch.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -653,4 +653,4 @@ OVERRIDE_EVALUATOR(Eigen::DefaultDevice);
}; // namespace Eigen
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
diff --git a/tensorflow/core/kernels/eye_functor.h b/tensorflow/core/kernels/eye_functor.h
index 70f093f813..3799cfba9a 100644
--- a/tensorflow/core/kernels/eye_functor.h
+++ b/tensorflow/core/kernels/eye_functor.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
#include "tensorflow/core/framework/tensor_types.h"
@@ -29,4 +29,4 @@ struct EyeFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h
index 7aaad6e6c7..81189866c3 100644
--- a/tensorflow/core/kernels/fake_quant_ops_functor.h
+++ b/tensorflow/core/kernels/fake_quant_ops_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
#include <tuple>
@@ -277,4 +277,4 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/gather_functor_gpu.cu.h b/tensorflow/core/kernels/gather_functor_gpu.cu.h
index a50b51b54b..11ea63d730 100644
--- a/tensorflow/core/kernels/gather_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/gather_functor_gpu.cu.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_
#if GOOGLE_CUDA
@@ -118,4 +118,4 @@ struct GatherFunctor<GPUDevice, T, Index> {
#endif // GOOGLE_CUDA
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
index 366877bcf5..ffc733e6bb 100644
--- a/tensorflow/core/kernels/gpu_utils.h
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
#if GOOGLE_CUDA
@@ -162,4 +162,4 @@ class AutoTuneSingleton {
#endif // GOOGLE_CUDA
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
+#endif // TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h
index 125d1fd200..a360d188cc 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.h
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
#include <array>
#include <unordered_map>
@@ -225,4 +225,4 @@ class GraphTransferer {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
index 8eb3995fc4..dca1f94a9b 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
#include <unordered_map>
#include <vector>
@@ -88,4 +88,4 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
index 993a5f9a3a..b9328c8e0e 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
+++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
#include <unordered_map>
@@ -55,4 +55,4 @@ class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H
diff --git a/tensorflow/core/kernels/i_remote_fused_graph_executor.h b/tensorflow/core/kernels/i_remote_fused_graph_executor.h
index 05b76172b2..eb6b64da58 100644
--- a/tensorflow/core/kernels/i_remote_fused_graph_executor.h
+++ b/tensorflow/core/kernels/i_remote_fused_graph_executor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
+#define TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.h"
@@ -72,4 +72,4 @@ class IRemoteFusedGraphExecutor {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
diff --git a/tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h b/tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h
index 7d3329f490..9e51c9f51f 100644
--- a/tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h
+++ b/tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
+#define TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
@@ -43,4 +43,4 @@ class IRemoteFusedGraphOpsDefinitions {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
+#endif // TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_OPS_DEFINITIONS_H_
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index 5e405f16a4..baf0a4abe4 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -87,6 +87,14 @@ REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);
+Status TensorListShape(const TensorList& t, TensorShape* s) {
+ *s = TensorShape({});
+ return Status::OK();
+}
+
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName,
+ TensorListShape);
+
bool TensorList::Decode(const VariantTensorData& data) {
tensors = data.tensors();
string metadata;
@@ -251,6 +259,45 @@ REGISTER_KERNEL_BUILDER(
#endif // GOOGLE_CUDA
+class TensorListElementShape : public OpKernel {
+ public:
+ explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}
+
+ void Compute(OpKernelContext* c) override {
+ OP_REQUIRES(
+ c, c->input(0).shape().num_elements() == 1,
+ errors::InvalidArgument("List tensors are supposed to be scalars."));
+ const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
+ OP_REQUIRES(c, l != nullptr,
+ errors::InvalidArgument(
+ "TensorListElementShape received a variant which is not a "
+ "list. Saw: '",
+ c->input(0).scalar<Variant>()().DebugString(), "'"));
+ Tensor* result;
+ OP_REQUIRES_OK(c, c->allocate_output(
+ 0, TensorShape{l->element_shape.dims()}, &result));
+ for (int i = 0; i < l->element_shape.dims(); ++i) {
+ if (result->dtype() == DT_INT32) {
+ result->flat<int32>()(i) = l->element_shape.dim_size(i);
+ } else {
+ result->flat<int64>()(i) = l->element_shape.dim_size(i);
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
+ TensorListElementShape);
+
+#if GOOGLE_CUDA
+
+REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
+ .Device(DEVICE_GPU)
+ .HostMemory("element_shape"),
+ TensorListElementShape);
+
+#endif // GOOGLE_CUDA
+
class TensorListPopBack : public OpKernel {
public:
explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
@@ -299,6 +346,134 @@ REGISTER_KERNEL_BUILDER(Name("TensorListPopBack").Device(DEVICE_GPU),
#endif // GOOGLE_CUDA
+class TensorListReserve : public OpKernel {
+ public:
+ explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ PartialTensorShape element_shape;
+ OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
+ int32 num_elements = c->input(1).scalar<int32>()();
+ TensorList output;
+ output.element_shape = element_shape;
+ output.element_dtype = element_dtype_;
+ output.tensors.resize(num_elements, Tensor(DT_INVALID));
+ Tensor* result;
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
+ result->scalar<Variant>()() = std::move(output);
+ }
+
+ private:
+ DataType element_dtype_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
+ TensorListReserve);
+
+#if GOOGLE_CUDA
+
+REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
+ .Device(DEVICE_GPU)
+ .HostMemory("element_shape")
+ .HostMemory("num_elements"),
+ TensorListReserve);
+
+#endif // GOOGLE_CUDA
+
+class TensorListGetItem : public OpKernel {
+ public:
+ explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ OP_REQUIRES(
+ c, c->input(0).shape().num_elements() == 1,
+ errors::InvalidArgument("List tensors are supposed to be scalars."));
+ const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
+ OP_REQUIRES(c, l != nullptr,
+ errors::InvalidArgument(
+ "Input handle is not a list. Saw: '",
+ c->input(0).scalar<Variant>()().DebugString(), "'"));
+ OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+ errors::InvalidArgument("Invalid data types; op elements ",
+ DataTypeString(element_dtype_),
+ " but list elements ",
+ DataTypeString(l->element_dtype)));
+ int32 index = c->input(1).scalar<int32>()();
+ OP_REQUIRES(c, index < l->tensors.size(),
+ errors::InvalidArgument("Trying to access element ", index,
+ " in a list with ", l->tensors.size(),
+ " elements."));
+ c->set_output(0, l->tensors[index]);
+ }
+
+ private:
+ DataType element_dtype_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("TensorListGetItem").Device(DEVICE_CPU),
+ TensorListGetItem);
+
+#if GOOGLE_CUDA
+
+REGISTER_KERNEL_BUILDER(
+ Name("TensorListGetItem").Device(DEVICE_GPU).HostMemory("index"),
+ TensorListGetItem);
+
+#endif // GOOGLE_CUDA
+
+class TensorListSetItem : public OpKernel {
+ public:
+ explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
+ OP_REQUIRES(c, l != nullptr,
+ errors::InvalidArgument(
+ "Input handle is not a list. Saw: '",
+ c->input(0).scalar<Variant>()().DebugString(), "'"));
+ OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+ errors::InvalidArgument("Invalid data types; op elements ",
+ DataTypeString(element_dtype_),
+ " but list elements ",
+ DataTypeString(l->element_dtype)));
+ int32 index = c->input(1).scalar<int32>()();
+ OP_REQUIRES(c, index < l->tensors.size(),
+ errors::InvalidArgument("Trying to modify element ", index,
+ " in a list with ", l->tensors.size(),
+ " elements."));
+ TensorList output;
+ output = *l;
+ output.tensors[index] = c->input(2);
+ Tensor* result;
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
+ result->scalar<Variant>()() = std::move(output);
+ }
+
+ private:
+ DataType element_dtype_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
+ TensorListSetItem);
+
+#if GOOGLE_CUDA
+
+REGISTER_KERNEL_BUILDER(
+ Name("TensorListSetItem").Device(DEVICE_GPU).HostMemory("index"),
+ TensorListSetItem);
+
+#endif // GOOGLE_CUDA
+
#define REGISTER_TENSOR_LIST_STACK_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 6a2a572b6d..9733883001 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
+#define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
@@ -76,14 +76,14 @@ class TensorListStack : public OpKernel {
errors::InvalidArgument(
"Input handle is not a list. Saw: '",
c->input(0).scalar<Variant>()().DebugString(), "'"));
- OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
- errors::InvalidArgument("Tried to stack elements from a list "
- "with non-fully-defined shape."));
OP_REQUIRES(c, element_dtype_ == l->element_dtype,
errors::InvalidArgument("Invalid data types; op elements ",
DataTypeString(element_dtype_),
" but list elements ",
DataTypeString(l->element_dtype)));
+ OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
+ errors::InvalidArgument("Tried to stack elements from a list "
+ "with non-fully-defined shape."));
if (num_elements_ != -1) {
OP_REQUIRES(c, l->tensors.size() == num_elements_,
errors::InvalidArgument("Operation expected a list with ",
@@ -98,16 +98,23 @@ class TensorListStack : public OpKernel {
}
Tensor* output;
OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output));
+ if (output->NumElements() == 0) {
+ return;
+ }
ConstMatrixVector inputs_flat;
inputs_flat.reserve(l->tensors.size());
for (const auto& t : l->tensors) {
+ OP_REQUIRES(
+ c, l->element_shape.IsCompatibleWith(t.shape()),
+ errors::InvalidArgument(
+ "Tensor with invalid shape in list. List element shape shape: ",
+ l->element_shape.DebugString(),
+ " and tensor shape: ", t.shape().DebugString()));
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
t.shaped<T, 2>({1, t.NumElements()})));
}
- auto output_flat =
- output->shaped<T, 2>({1, static_cast<int64>(l->tensors.size()) *
- l->element_shape.num_elements()});
+ auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if GOOGLE_CUDA
if (std::is_same<Device, Eigen::GpuDevice>::value) {
@@ -195,17 +202,26 @@ Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
for (int i = 0; i < a.tensors.size(); ++i) {
const Tensor& a_tensor = a.tensors[i];
const Tensor& b_tensor = b.tensors[i];
+ if (a_tensor.dtype() == DT_INVALID) {
+ out->tensors.push_back(b_tensor);
+ continue;
+ }
+ if (b_tensor.dtype() == DT_INVALID) {
+ out->tensors.push_back(a_tensor);
+ continue;
+ }
if (a_tensor.shape() != b_tensor.shape()) {
// TODO(apassos) support broadcasting additions here?
return errors::InvalidArgument(
"Trying to add two tensors with incompatible element shapes. "
"One is ",
a_tensor.shape().DebugString(), " and the other is ",
- b_tensor.shape().DebugString());
+ b_tensor.shape().DebugString(), " in position ", i);
}
Tensor out_tensor;
TF_RETURN_IF_ERROR(
c->allocate_temp(a_tensor.dtype(), a_tensor.shape(), &out_tensor));
+ out->tensors.push_back(out_tensor);
switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype) \
case DataTypeToEnum<dtype>::value: \
@@ -254,4 +270,4 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
+#endif // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
diff --git a/tensorflow/core/kernels/meta_support.h b/tensorflow/core/kernels/meta_support.h
index 53aece78e8..97f39eb598 100644
--- a/tensorflow/core/kernels/meta_support.h
+++ b/tensorflow/core/kernels/meta_support.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
-#define THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
+#ifndef TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
+#define TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
#include "meta/multi_thread_gemm.h"
#include "meta/multi_thread_transform.h"
@@ -109,4 +109,4 @@ void Clamp(OpKernelContext* context, const quint8* input, int input_count,
} // namespace meta
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
+#endif // TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
diff --git a/tensorflow/core/kernels/mfcc.h b/tensorflow/core/kernels/mfcc.h
index 0d5d9fb90f..8268f47203 100644
--- a/tensorflow/core/kernels/mfcc.h
+++ b/tensorflow/core/kernels/mfcc.h
@@ -15,8 +15,8 @@ limitations under the License.
// Basic class for computing MFCCs from spectrogram slices.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MFCC_H_
+#define TENSORFLOW_CORE_KERNELS_MFCC_H_
#include <vector>
@@ -74,4 +74,4 @@ class Mfcc {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
+#endif // TENSORFLOW_CORE_KERNELS_MFCC_H_
diff --git a/tensorflow/core/kernels/mfcc_dct.h b/tensorflow/core/kernels/mfcc_dct.h
index 4fa3c01628..888b8e8df8 100644
--- a/tensorflow/core/kernels/mfcc_dct.h
+++ b/tensorflow/core/kernels/mfcc_dct.h
@@ -15,8 +15,8 @@ limitations under the License.
// Basic minimal DCT class for MFCC speech processing.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
+#define TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
#include <vector>
@@ -41,4 +41,4 @@ class MfccDct {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
+#endif // TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.h b/tensorflow/core/kernels/mfcc_mel_filterbank.h
index a766a20cbc..1bdc2dc93b 100644
--- a/tensorflow/core/kernels/mfcc_mel_filterbank.h
+++ b/tensorflow/core/kernels/mfcc_mel_filterbank.h
@@ -15,8 +15,8 @@ limitations under the License.
// Basic class for applying a mel-scale mapping to a power spectrum.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
+#define TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
@@ -63,4 +63,4 @@ class MfccMelFilterbank {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
+#endif // TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
diff --git a/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h b/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
index bb22b2aa91..6716a26fac 100644
--- a/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
+++ b/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
+#define TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
#define EIGEN_USE_THREADS
@@ -41,4 +41,4 @@ TF_CALL_NUMBER_TYPES(DEFINE_CPU_SPECS);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
+#endif // TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index 44b94be3a0..89d37d2f87 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -61,6 +61,16 @@ class MklAddNOp : public OpKernel {
GetMklShape(ctx, src2_idx, &(mkl_context.input2_shape));
bool input2_in_mkl_format = mkl_context.input2_shape.IsMklTensor();
+ // if the shapes of two tensors are not same raise op error
+ TensorShape src1_shape, src2_shape;
+ src1_shape = input0.shape();
+ src2_shape = input1.shape();
+ if (!src1_shape.IsSameSize(src2_shape)) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Inputs to operation ", this->name(), " of type ",
+ this->type_string(), " must have the same size and shape. Input 0: ",
+ src1_shape.DebugString(), " != input 1: ", src2_shape.DebugString()));
+ }
// handle the case of a scalar
if (!input1_in_mkl_format && input0.dims() == 0) {
const TensorShape& o_shape = input0.shape();
@@ -70,17 +80,16 @@ class MklAddNOp : public OpKernel {
mkl_context.output_shape);
float user_i1 = (input0.scalar<T>()());
float user_i2 = (input1.scalar<T>()());
- out_tensor->scalar<T>()() =
- std::plus<float>{}(user_i1, user_i2);
+ out_tensor->scalar<T>()() = std::plus<float>{}(user_i1, user_i2);
return;
}
mkl_context.in_dims = input1_in_mkl_format
- ? mkl_context.input1_shape.GetDimension()
- : input0.dims();
+ ? mkl_context.input1_shape.GetDimension()
+ : input0.dims();
mkl_context.in_dims = input2_in_mkl_format
- ? mkl_context.input2_shape.GetDimension()
- : input1.dims();
+ ? mkl_context.input2_shape.GetDimension()
+ : input1.dims();
// If there is nothing to compute, return.
if (!input1_in_mkl_format && !input2_in_mkl_format) {
@@ -89,7 +98,7 @@ class MklAddNOp : public OpKernel {
Tensor* out_tensor = nullptr;
mkl_context.output_shape.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, src1_idx, &out_tensor, o_shape,
- mkl_context.output_shape);
+ mkl_context.output_shape);
return;
}
}
@@ -98,9 +107,9 @@ class MklAddNOp : public OpKernel {
mkl_context.in_strides = new size_t[mkl_context.in_dims];
// Generate size, stride for input if input is in MKL format.
if (input1_in_mkl_format || input2_in_mkl_format) {
- const MklShape* tmp_mkl_shape =
- (input1_in_mkl_format) ? &mkl_context.input1_shape :
- &mkl_context.input2_shape;
+ const MklShape* tmp_mkl_shape = (input1_in_mkl_format)
+ ? &mkl_context.input1_shape
+ : &mkl_context.input2_shape;
for (int i = 0; i < mkl_context.in_dims; i++) {
mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
@@ -124,32 +133,33 @@ class MklAddNOp : public OpKernel {
Tensor mkl_tmp_input1_buf_tensor, mkl_tmp_input2_buf_tensor;
mkl_context.MklPrepareAddNInputs(ctx, &mkl_tmp_input1_buf_tensor,
- &mkl_tmp_input2_buf_tensor);
+ &mkl_tmp_input2_buf_tensor);
Tensor* output = nullptr;
if (input1_in_mkl_format || input2_in_mkl_format) {
- TensorShape tf_shape;
- mkl_context.output_shape.SetMklTensor(true);
- mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise, dnnResourceDst);
-
- mkl_context.output_shape.SetTfLayout(
- mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
- if (input1_in_mkl_format == true) {
- mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims,
- mkl_context.input1_shape.GetTfToMklDimMap());
- } else {
- mkl_context.output_shape.SetTfDimOrder(mkl_context.in_dims,
- mkl_context.input2_shape.GetTfToMklDimMap());
- }
- tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
- mkl_context.output_shape.GetMklLayout())) /
- sizeof(T));
-
- AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape,
- mkl_context.output_shape);
+ TensorShape tf_shape;
+ mkl_context.output_shape.SetMklTensor(true);
+ mkl_context.output_shape.SetMklLayout(mkl_context.Eltwise,
+ dnnResourceDst);
+
+ mkl_context.output_shape.SetTfLayout(
+ mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
+ if (input1_in_mkl_format == true) {
+ mkl_context.output_shape.SetTfDimOrder(
+ mkl_context.in_dims, mkl_context.input1_shape.GetTfToMklDimMap());
+ } else {
+ mkl_context.output_shape.SetTfDimOrder(
+ mkl_context.in_dims, mkl_context.input2_shape.GetTfToMklDimMap());
+ }
+ tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+ mkl_context.output_shape.GetMklLayout())) /
+ sizeof(T));
+
+ AllocateOutputSetMklShape(ctx, src1_idx, &output, tf_shape,
+ mkl_context.output_shape);
} else {
- const TensorShape& o_shape = input1.shape();
- mkl_context.output_shape.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape,
+ const TensorShape& o_shape = input1.shape();
+ mkl_context.output_shape.SetMklTensor(false);
+ AllocateOutputSetMklShape(ctx, src1_idx, &output, o_shape,
mkl_context.output_shape);
}
@@ -177,18 +187,16 @@ class MklAddNOp : public OpKernel {
void MklCreateInputLayouts(OpKernelContext* context) {
bool input1_in_mkl_format = input1_shape.IsMklTensor();
if (!input1_in_mkl_format) {
- CHECK_EQ(
- dnnLayoutCreate_F32(&lt_input1, in_dims, in_sizes, in_strides),
- E_SUCCESS);
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_input1, in_dims, in_sizes, in_strides),
+ E_SUCCESS);
} else {
lt_input1 = static_cast<dnnLayout_t>(input1_shape.GetCurLayout());
}
bool input2_in_mkl_format = input2_shape.IsMklTensor();
if (!input2_in_mkl_format) {
- CHECK_EQ(
- dnnLayoutCreate_F32(&lt_input2, in_dims, in_sizes, in_strides),
- E_SUCCESS);
+ CHECK_EQ(dnnLayoutCreate_F32(&lt_input2, in_dims, in_sizes, in_strides),
+ E_SUCCESS);
} else {
lt_input2 = static_cast<dnnLayout_t>(input2_shape.GetCurLayout());
}
@@ -264,14 +272,14 @@ class MklAddNOp : public OpKernel {
bool input2_in_mkl_format = input2_shape.IsMklTensor();
dnnDelete_F32(Eltwise);
if (!input1_in_mkl_format || !input2_in_mkl_format) {
- delete [] in_sizes;
- delete [] in_strides;
+ delete[] in_sizes;
+ delete[] in_strides;
}
if (!input1_in_mkl_format) {
- dnnLayoutDelete_F32(lt_input1);
+ dnnLayoutDelete_F32(lt_input1);
}
if (!input2_in_mkl_format) {
- dnnLayoutDelete_F32(lt_input2);
+ dnnLayoutDelete_F32(lt_input2);
}
}
} MklAddNOpContext;
@@ -303,33 +311,44 @@ class MklAddNOp : public OpKernel {
GetMklShape(ctx, src2_idx, &src2_mkl_shape);
bool input1_in_mkl_format = src1_mkl_shape.IsMklTensor();
bool input2_in_mkl_format = src2_mkl_shape.IsMklTensor();
- int src1_dims_size = input1_in_mkl_format?
- src1_mkl_shape.GetDimension(): src1_tensor.dims();
- int src2_dims_size = input2_in_mkl_format?
- src2_mkl_shape.GetDimension(): src2_tensor.dims();
+ int src1_dims_size = input1_in_mkl_format ? src1_mkl_shape.GetDimension()
+ : src1_tensor.dims();
+ int src2_dims_size = input2_in_mkl_format ? src2_mkl_shape.GetDimension()
+ : src2_tensor.dims();
+ // if the shapes of two tensors are not same raise op error
+ TensorShape src1_shape, src2_shape;
+ src1_shape = src1_tensor.shape();
+ src2_shape = src2_tensor.shape();
+ if (!src1_shape.IsSameSize(src2_shape)) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Inputs to operation ", this->name(), " of type ",
+ this->type_string(),
+ " must have the same size and shape. Input 0: ",
+ src1_shape.DebugString(),
+ " != input 1: ", src2_shape.DebugString()));
+ }
if (!input1_in_mkl_format && src1_dims_size == 0) {
- Tensor* dst_tensor = nullptr;
- MklShape mkl_shape_dst;
- mkl_shape_dst.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
- src1_tensor.shape(), mkl_shape_dst);
- float user_i1 = (src1_tensor.scalar<T>()());
- float user_i2 = (src2_tensor.scalar<T>()());
- dst_tensor->scalar<T>()() =
- std::plus<float>{}(user_i1, user_i2);
- return;
- }
+ Tensor* dst_tensor = nullptr;
+ MklShape mkl_shape_dst;
+ mkl_shape_dst.SetMklTensor(false);
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
+ src1_tensor.shape(), mkl_shape_dst);
+ float user_i1 = (src1_tensor.scalar<T>()());
+ float user_i2 = (src2_tensor.scalar<T>()());
+ dst_tensor->scalar<T>()() = std::plus<float>{}(user_i1, user_i2);
+ return;
+ }
// If there is nothing to compute, return.
if (!input1_in_mkl_format && !input2_in_mkl_format) {
if (src1_tensor.shape().num_elements() == 0) {
- Tensor* dst_tensor = nullptr;
- MklShape mkl_shape_dst;
- mkl_shape_dst.SetMklTensor(false);
- AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
- src1_tensor.shape(), mkl_shape_dst);
- return;
+ Tensor* dst_tensor = nullptr;
+ MklShape mkl_shape_dst;
+ mkl_shape_dst.SetMklTensor(false);
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
+ src1_tensor.shape(), mkl_shape_dst);
+ return;
}
}
@@ -338,7 +357,7 @@ class MklAddNOp : public OpKernel {
MklDnnData<T> src2(&cpu_engine);
MklDnnData<T> dst(&cpu_engine);
- int tmp_size = input1_in_mkl_format ? src2_dims_size: src1_dims_size;
+ int tmp_size = input1_in_mkl_format ? src2_dims_size : src1_dims_size;
memory::dims dims(tmp_size);
memory::dims strides(tmp_size);
memory::desc md1({}, memory::data_undef, memory::format_undef);
@@ -368,21 +387,19 @@ class MklAddNOp : public OpKernel {
md1 = src1_mkl_shape.GetMklLayout();
memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
- auto src1_tf_data_format = MklDnnDataFormatToTFDataFormat(
- src1_mkl_data_format);
- auto src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(),
- src1_tf_data_format);
- md2 = memory::desc(src2_dims, MklDnnType<T>(),
- src1_mkl_data_format);
+ auto src1_tf_data_format =
+ MklDnnDataFormatToTFDataFormat(src1_mkl_data_format);
+ auto src2_dims =
+ TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format);
+ md2 = memory::desc(src2_dims, MklDnnType<T>(), src1_mkl_data_format);
} else if (input2_in_mkl_format && !input1_in_mkl_format) {
// Same comment as above.
memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
- auto src2_tf_data_format = MklDnnDataFormatToTFDataFormat(
- src2_mkl_data_format);
- auto src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(),
- src2_tf_data_format);
- md1 = memory::desc(src1_dims, MklDnnType<T>(),
- src2_mkl_data_format);
+ auto src2_tf_data_format =
+ MklDnnDataFormatToTFDataFormat(src2_mkl_data_format);
+ auto src1_dims =
+ TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format);
+ md1 = memory::desc(src1_dims, MklDnnType<T>(), src2_mkl_data_format);
md2 = src2_mkl_shape.GetMklLayout();
} else {
@@ -456,20 +473,19 @@ class MklAddNOp : public OpKernel {
output_mkl_shape.SetMklTensor(false);
output_tf_shape = src1_tensor.shape();
}
- AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
- output_tf_shape, output_mkl_shape);
+ AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor, output_tf_shape,
+ output_mkl_shape);
dst.SetUsrMemDataHandle(dst_tensor);
// Create Sum op, and submit net for execution.
net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
stream(stream::kind::eager).submit(net).wait();
- } catch (mkldnn::error &e) {
+ } catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) +
- ", in file " + string(__FILE__) + ":" +
- std::to_string(__LINE__);
- OP_REQUIRES_OK(ctx, errors::Aborted("Operation received an exception:",
- error_msg));
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ ctx, errors::Aborted("Operation received an exception:", error_msg));
}
}
};
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index 001834b13b..4b5f7b8310 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -396,7 +396,7 @@ class MklInputConversionOp : public OpKernel {
auto cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> tf_input(&cpu_engine);
auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
- tf_input.SetUsrMem(input_tf_md, &tf_tensor);
+ tf_input.SetUsrMem(input_tf_md, tf_tensor);
// Create reorder between tensorflow layout and Mkl layout.
std::vector<primitive> net;
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index 66bc7dd8ee..95e0404ba8 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -43,7 +43,7 @@ limitations under the License.
using mkldnn::lrn_forward;
using mkldnn::lrn_backward;
using mkldnn::prop_kind;
-using mkldnn::algorithm::lrn_across_channels;
+using mkldnn::lrn_across_channels;
using mkldnn::stream;
#endif
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 896d562933..c46eabdde1 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -17,13 +17,13 @@ limitations under the License.
#ifdef INTEL_MKL
#ifdef INTEL_MKL_DNN
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "mkldnn.h"
#include "mkldnn_types.h"
@@ -31,16 +31,14 @@ limitations under the License.
#include "tensorflow/core/util/mkl_util.h"
#include "mkldnn.hpp"
-using mkldnn::stream;
using mkldnn::prop_kind;
using mkldnn::softmax_forward;
+using mkldnn::stream;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
-
-
template <typename Device, typename T>
class MklSoftmaxOp : public OpKernel {
public:
@@ -60,11 +58,11 @@ class MklSoftmaxOp : public OpKernel {
MklDnnShape src_mkl_shape;
GetMklShape(context, src_idx, &src_mkl_shape);
-
// src_dims is the dimenstion of src_tensor
// dim of the dst will also be same as src_dims
- auto src_tf_shape = src_mkl_shape.IsMklTensor() ?
- src_mkl_shape.GetTfShape() : src_tensor.shape();
+ auto src_tf_shape = src_mkl_shape.IsMklTensor()
+ ? src_mkl_shape.GetTfShape()
+ : src_tensor.shape();
auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
auto output_dims = src_dims;
@@ -77,10 +75,10 @@ class MklSoftmaxOp : public OpKernel {
// construct input Tf layout. For TF layout, although input shape
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
// layout
- auto src_md = src_mkl_shape.IsMklTensor()
- ? src_mkl_shape.GetMklLayout()
- : memory::desc(src_dims, MklDnnType<T>(),
- memory::format::nc);
+ auto src_md =
+ src_mkl_shape.IsMklTensor()
+ ? src_mkl_shape.GetMklLayout()
+ : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
// src: setting memory descriptor and op memory descriptor
// Basically following two functions maps the TF "src_tensor" to mkl
@@ -95,8 +93,8 @@ class MklSoftmaxOp : public OpKernel {
int axis = 1; // axis to which softmax will be applied
auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
src.GetOpMemDesc(), axis);
- auto softmax_fwd_pd = softmax_forward::primitive_desc(softmax_fwd_desc,
- cpu_engine);
+ auto softmax_fwd_pd =
+ softmax_forward::primitive_desc(softmax_fwd_desc, cpu_engine);
// add: output
Tensor* output_tensor = nullptr;
@@ -136,9 +134,9 @@ class MklSoftmaxOp : public OpKernel {
net.push_back(softmax_fwd);
stream(stream::kind::eager).submit(net).wait();
} catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
- string(e.message) + ", in file " + string(__FILE__) +
- ":" + std::to_string(__LINE__);
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@@ -148,7 +146,7 @@ class MklSoftmaxOp : public OpKernel {
/* Register DNN kernels for supported operations and supported types - right now
* it is only Softmax and f32 */
-#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type) \
+#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type) \
REGISTER_KERNEL_BUILDER(Name("_MklSoftmax") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
@@ -156,7 +154,6 @@ class MklSoftmaxOp : public OpKernel {
MklSoftmaxOp<CPUDevice, type>);
TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
-
} // namespace tensorflow
#endif // INTEL_MKL_DNN
diff --git a/tensorflow/core/kernels/neon/BUILD b/tensorflow/core/kernels/neon/BUILD
index 536b2bdc03..c3d24e50ef 100644
--- a/tensorflow/core/kernels/neon/BUILD
+++ b/tensorflow/core/kernels/neon/BUILD
@@ -39,6 +39,6 @@ tf_kernel_library(
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:ops_util",
- "@gemmlowp//:gemmlowp",
+ "@gemmlowp",
],
)
diff --git a/tensorflow/core/kernels/neon/depthwiseconv_float.h b/tensorflow/core/kernels/neon/depthwiseconv_float.h
index acd58a644f..11f5be7c03 100644
--- a/tensorflow/core/kernels/neon/depthwiseconv_float.h
+++ b/tensorflow/core/kernels/neon/depthwiseconv_float.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
+#define TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
#include "public/gemmlowp.h"
#include "tensorflow/core/kernels/neon/types.h"
@@ -722,4 +722,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
} // end namespace neon
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
+#endif // TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
diff --git a/tensorflow/core/kernels/neon/types.h b/tensorflow/core/kernels/neon/types.h
index 4ece22f015..05ff1bcc6c 100644
--- a/tensorflow/core/kernels/neon/types.h
+++ b/tensorflow/core/kernels/neon/types.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_
+#define TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_
#include "tensorflow/core/platform/logging.h"
@@ -70,4 +70,4 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
} // end namespace neon
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_
+#endif // TENSORFLOW_CORE_KERNELS_NEON_TYPES_H_
diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc
index 2923c38662..2033fbf5dc 100644
--- a/tensorflow/core/kernels/pack_op.cc
+++ b/tensorflow/core/kernels/pack_op.cc
@@ -139,7 +139,6 @@ class PackOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);
-TF_CALL_bfloat16(REGISTER_PACK);
TF_CALL_variant(REGISTER_PACK);
#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index 6a52a15c93..d4241b5809 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -222,7 +222,7 @@ void DnnPoolingOp<T>::Compute(
output_desc, &output_data)
.ok();
OP_REQUIRES(context, status,
- errors::Internal("cudnn PoolBackward launch failed"));
+ errors::Internal("cudnn PoolForward launch failed"));
if (data_format == FORMAT_NHWC) {
/// Transform the output data from NCHW back to NHWC
diff --git a/tensorflow/core/kernels/population_count_op.h b/tensorflow/core/kernels/population_count_op.h
index de89582e13..2c98129673 100644
--- a/tensorflow/core/kernels/population_count_op.h
+++ b/tensorflow/core/kernels/population_count_op.h
@@ -14,8 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -35,4 +35,4 @@ struct PopulationCount {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_
diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h
index 7c18496357..9fafe6bb65 100644
--- a/tensorflow/core/kernels/quantization_utils.h
+++ b/tensorflow/core/kernels/quantization_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
#define EIGEN_USE_THREADS
@@ -956,4 +956,4 @@ class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
+#endif // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
diff --git a/tensorflow/core/kernels/reference_gemm.h b/tensorflow/core/kernels/reference_gemm.h
index bb2a21720f..c9cc04ed1b 100644
--- a/tensorflow/core/kernels/reference_gemm.h
+++ b/tensorflow/core/kernels/reference_gemm.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
+#define TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
#include <stdlib.h>
@@ -92,4 +92,4 @@ void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c,
}
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
+#endif // TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
index 3fa052108e..7de45eaaa1 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -86,4 +86,4 @@ class TestRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
+#endif // TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_OP_TEST_UTILS_H_
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
index 541c26baaf..f047144278 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
#include <unordered_map>
#include <unordered_set>
@@ -312,4 +312,4 @@ class RemoteFusedGraphExecuteUtils {
};
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
+#endif // TENSORFLOW_CORE_KERNELS_REMOTE_FUSED_GRAPH_EXECUTE_UTILS_H_
diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h
index ed583afd13..6777748b63 100644
--- a/tensorflow/core/kernels/reshape_util.h
+++ b/tensorflow/core/kernels/reshape_util.h
@@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -28,4 +28,4 @@ void Reshape(OpKernelContext *context, const Tensor &input_indices_in,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
+#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
index cffc326174..c6c9d4e658 100644
--- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
// Functor definitions for ScatterND ops, must be compilable by nvcc.
@@ -257,4 +257,4 @@ REGISTER_SCATTER_ND_MATH_SYCL(int32);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
index 31f74671ca..a3c21edc15 100644
--- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
@@ -55,6 +55,27 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
}
};
+// Specializations for std::complex, updating real and imaginary part
+// individually. Even though this is not an atomic op anymore, it is safe
+// because there is only one type of op per kernel.
+template <typename T>
+struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
+ std::complex<T>* out, const std::complex<T>& val) {
+ T* ptr = reinterpret_cast<T*>(out);
+ CudaAtomicAdd(ptr, val.real());
+ CudaAtomicAdd(ptr, val.imag());
+ }
+};
+
+template <typename T>
+struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
+ std::complex<T>* out, const std::complex<T>& val) {
+ LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD>()(out, -val);
+ }
+};
+
} // namespace
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index b10bea72ba..bcdd42c80c 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
@@ -98,4 +98,4 @@ struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor<Device, T, I
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
diff --git a/tensorflow/core/kernels/slice_op_cpu_impl.h b/tensorflow/core/kernels/slice_op_cpu_impl.h
index 58dc7df3e0..47f1d5342a 100644
--- a/tensorflow/core/kernels/slice_op_cpu_impl.h
+++ b/tensorflow/core/kernels/slice_op_cpu_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_
#define EIGEN_USE_THREADS
@@ -47,4 +47,4 @@ DEFINE_SYCL_KERNELS(int32);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/spectrogram.h b/tensorflow/core/kernels/spectrogram.h
index 5476a0a961..fef0e64942 100644
--- a/tensorflow/core/kernels/spectrogram.h
+++ b/tensorflow/core/kernels/spectrogram.h
@@ -28,8 +28,8 @@ limitations under the License.
// window = hann(window_length_samples, 'periodic');
// S = abs(spectrogram(audio, window, overlap_samples)).^2;
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
+#define TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
#include <complex>
#include <deque>
@@ -109,4 +109,4 @@ class Spectrogram {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
diff --git a/tensorflow/core/kernels/spectrogram_test_utils.cc b/tensorflow/core/kernels/spectrogram_test_utils.cc
index 046f6344df..872a6e9d1b 100644
--- a/tensorflow/core/kernels/spectrogram_test_utils.cc
+++ b/tensorflow/core/kernels/spectrogram_test_utils.cc
@@ -70,10 +70,24 @@ bool ReadRawFloatFileToComplexVector(
int offset = 0;
const int end = data_string.size();
while (offset < end) {
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ char arr[4];
+ for (int i = 0; i < kBytesPerValue; ++i) {
+ arr[3 - i] = *(data_string.data() + offset + i);
+ }
+ memcpy(&real_out, arr, kBytesPerValue);
+ offset += kBytesPerValue;
+ for (int i = 0; i < kBytesPerValue; ++i) {
+ arr[3 - i] = *(data_string.data() + offset + i);
+ }
+ memcpy(&imag_out, arr, kBytesPerValue);
+ offset += kBytesPerValue;
+#else
memcpy(&real_out, data_string.data() + offset, kBytesPerValue);
offset += kBytesPerValue;
memcpy(&imag_out, data_string.data() + offset, kBytesPerValue);
offset += kBytesPerValue;
+#endif
if (row_counter >= row_length) {
data->push_back(data_row);
data_row.clear();
diff --git a/tensorflow/core/kernels/spectrogram_test_utils.h b/tensorflow/core/kernels/spectrogram_test_utils.h
index 59a903549e..d4187076e7 100644
--- a/tensorflow/core/kernels/spectrogram_test_utils.h
+++ b/tensorflow/core/kernels/spectrogram_test_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
#include <complex>
#include <string>
@@ -78,4 +78,4 @@ void SineWave(int sample_rate, float frequency, float duration_seconds,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index 41cbece1d6..d317a8d33d 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -42,11 +42,16 @@ class CreateSummaryFileWriterOp : public OpKernel {
const int32 flush_millis = tmp->scalar<int32>()();
OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
const string filename_suffix = tmp->scalar<string>()();
- SummaryWriterInterface* s;
- OP_REQUIRES_OK(ctx,
- CreateSummaryFileWriter(max_queue, flush_millis, logdir,
- filename_suffix, ctx->env(), &s));
- OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
+
+ SummaryWriterInterface* s = nullptr;
+ OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
+ ctx, HandleFromInput(ctx, 0), &s,
+ [max_queue, flush_millis, logdir, filename_suffix,
+ ctx](SummaryWriterInterface** s) {
+ return CreateSummaryFileWriter(
+ max_queue, flush_millis, logdir,
+ filename_suffix, ctx->env(), s);
+ }));
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
@@ -66,17 +71,23 @@ class CreateSummaryDbWriterOp : public OpKernel {
const string run_name = tmp->scalar<string>()();
OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
const string user_name = tmp->scalar<string>()();
- SummaryWriterInterface* s;
- Sqlite* db;
- OP_REQUIRES_OK(ctx, Sqlite::Open(db_uri,
- SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
- &db));
- core::ScopedUnref unref(db);
- OP_REQUIRES_OK(ctx, SetupTensorboardSqliteDb(db));
+
+ SummaryWriterInterface* s = nullptr;
OP_REQUIRES_OK(
- ctx, CreateSummaryDbWriter(db, experiment_name,
- run_name, user_name, ctx->env(), &s));
- OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
+ ctx,
+ LookupOrCreateResource<SummaryWriterInterface>(
+ ctx, HandleFromInput(ctx, 0), &s,
+ [db_uri, experiment_name, run_name, user_name,
+ ctx](SummaryWriterInterface** s) {
+ Sqlite* db;
+ TF_RETURN_IF_ERROR(Sqlite::Open(
+ db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db));
+ core::ScopedUnref unref(db);
+ TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
+ TF_RETURN_IF_ERROR(CreateSummaryDbWriter(
+ db, experiment_name, run_name, user_name, ctx->env(), s));
+ return Status::OK();
+ }));
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
@@ -267,8 +278,6 @@ class WriteAudioSummaryOp : public OpKernel {
private:
int max_outputs_;
- bool has_sample_rate_attr_;
- float sample_rate_attr_;
};
REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU),
WriteAudioSummaryOp);
diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc
index dedc2da60b..8c3a58b108 100644
--- a/tensorflow/core/kernels/svd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc
@@ -63,8 +63,8 @@ __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m,
int64 ldu, const Scalar* M,
const Scalar* U, const Scalar* S,
Scalar* V) {
- CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
+ CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
CudaAtomicAdd(V + batch, v);
}
diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl.h b/tensorflow/core/kernels/tile_ops_cpu_impl.h
index a6eed4935d..054b31ef9e 100644
--- a/tensorflow/core/kernels/tile_ops_cpu_impl.h
+++ b/tensorflow/core/kernels/tile_ops_cpu_impl.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
#define EIGEN_USE_THREADS
@@ -68,4 +68,4 @@ TF_CALL_int64(DEFINE_TYPE);
} // end namespace functor
} // end namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/tile_ops_gpu_impl.h b/tensorflow/core/kernels/tile_ops_gpu_impl.h
index 592f99e9b7..8da337dabd 100644
--- a/tensorflow/core/kernels/tile_ops_gpu_impl.h
+++ b/tensorflow/core/kernels/tile_ops_gpu_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_
// Header used to split up compilation of GPU tile ops. For each type you want
// to have tile ops, create a .cu.cc file containing
@@ -56,4 +56,4 @@ limitations under the License.
} \
}
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index 41b73fdaf4..5198df7e16 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -88,6 +88,18 @@ struct Transpose<CPUDevice, T, conjugate> {
internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, conjugate,
out);
break;
+ case 6:
+ internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
+ out);
+ break;
+ case 7:
+ internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
+ out);
+ break;
+ case 8:
+ internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate,
+ out);
+ break;
default:
TransposeSimple<T, conjugate>(d, in, perm, out);
break;
diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
index 493dac9a7c..d6a237d6c1 100644
--- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
@@ -201,6 +201,27 @@ struct Transpose<GPUDevice, T, conjugate> {
out);
}
break;
+ case 6:
+ if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
+ out)) {
+ internal::TransposeUsingEigen<GPUDevice, T, 6>(d, in, perm, conjugate,
+ out);
+ }
+ break;
+ case 7:
+ if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
+ out)) {
+ internal::TransposeUsingEigen<GPUDevice, T, 7>(d, in, perm, conjugate,
+ out);
+ }
+ break;
+ case 8:
+ if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
+ out)) {
+ internal::TransposeUsingEigen<GPUDevice, T, 8>(d, in, perm, conjugate,
+ out);
+ }
+ break;
default:
internal::TransposeSimple<T, conjugate>(d, in, perm, out);
break;
diff --git a/tensorflow/core/kernels/winograd_transform.h b/tensorflow/core/kernels/winograd_transform.h
index 5caee9fdc1..d22710e503 100644
--- a/tensorflow/core/kernels/winograd_transform.h
+++ b/tensorflow/core/kernels/winograd_transform.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
+#ifndef TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
+#define TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
#include "tensorflow/core/kernels/deep_conv2d.h"
@@ -374,4 +374,4 @@ void WinogradTransform<T>::GetOutputTransformMatrix(const int64 rows,
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
+#endif // TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index dc21cee3a8..0f8d027caa 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -67,10 +67,12 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
// Try to reuse the logits_in buffer for the backprop output.
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 1, logits_in.shape(), &back_out));
- functor::XentFunctor<Device, T> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
- back_out->matrix<T>());
+ if (logits_in.dim_size(0) > 0) {
+ functor::XentFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
+ labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
+ back_out->matrix<T>());
+ }
}
};
diff --git a/tensorflow/core/kernels/xsmm_conv2d.h b/tensorflow/core/kernels/xsmm_conv2d.h
index b439511dc7..003291329a 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.h
+++ b/tensorflow/core/kernels/xsmm_conv2d.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_
+#ifndef TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_
+#define TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -57,4 +57,4 @@ struct XsmmBkwFilterConv2D {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_
+#endif // TENSORFLOW_CORE_KERNELS_XSMM_CONV2D_H_
diff --git a/tensorflow/core/lib/core/bitmap.h b/tensorflow/core/lib/core/bitmap.h
index b30479fa1b..8ff1e666b4 100644
--- a/tensorflow/core/lib/core/bitmap.h
+++ b/tensorflow/core/lib/core/bitmap.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_CORE_BITMAP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_CORE_BITMAP_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_BITMAP_H_
+#define TENSORFLOW_CORE_LIB_CORE_BITMAP_H_
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -103,4 +103,4 @@ inline void Bitmap::clear(size_t i) {
} // namespace core
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_CORE_BITMAP_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_BITMAP_H_
diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc
index b5c0d9f621..0f6999c88f 100644
--- a/tensorflow/core/lib/gif/gif_io.cc
+++ b/tensorflow/core/lib/gif/gif_io.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/lib/gif/gif_io.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/gif.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
@@ -44,7 +45,8 @@ int input_callback(GifFileType* gif_file, GifByteType* buf, int size) {
}
uint8* Decode(const void* srcdata, int datasize,
- std::function<uint8*(int, int, int, int)> allocate_output) {
+ const std::function<uint8*(int, int, int, int)>& allocate_output,
+ string* error_string) {
int error_code = D_GIF_SUCCEEDED;
InputBufferInfo info = {reinterpret_cast<const uint8*>(srcdata), datasize};
GifFileType* gif_file =
@@ -57,17 +59,17 @@ uint8* Decode(const void* srcdata, int datasize,
}
});
if (error_code != D_GIF_SUCCEEDED) {
- LOG(ERROR) << "Fail to open gif file, reason: "
- << GifErrorString(error_code);
+ *error_string = strings::StrCat("failed to open gif file: ",
+ GifErrorString(error_code));
return nullptr;
}
if (DGifSlurp(gif_file) != GIF_OK) {
- LOG(ERROR) << "Fail to slurp gif file, reason: "
- << GifErrorString(gif_file->Error);
+ *error_string = strings::StrCat("failed to slurp gif file: ",
+ GifErrorString(gif_file->Error));
return nullptr;
}
if (gif_file->ImageCount <= 0) {
- LOG(ERROR) << "Gif file does not contain any image";
+ *error_string = strings::StrCat("gif file does not contain any image");
return nullptr;
}
@@ -83,7 +85,7 @@ uint8* Decode(const void* srcdata, int datasize,
GifImageDesc* img_desc = &this_image->ImageDesc;
if (img_desc->Left != 0 || img_desc->Top != 0 || img_desc->Width != width ||
img_desc->Height != height) {
- LOG(ERROR) << "Can't process optimized gif.";
+ *error_string = strings::StrCat("can't process optimized gif");
return nullptr;
}
diff --git a/tensorflow/core/lib/gif/gif_io.h b/tensorflow/core/lib/gif/gif_io.h
index 5399e6a538..0a7967a5a1 100644
--- a/tensorflow/core/lib/gif/gif_io.h
+++ b/tensorflow/core/lib/gif/gif_io.h
@@ -43,7 +43,8 @@ namespace tensorflow {
namespace gif {
uint8* Decode(const void* srcdata, int datasize,
- std::function<uint8*(int, int, int, int)> allocate_output);
+ const std::function<uint8*(int, int, int, int)>& allocate_output,
+ string* error_string);
} // namespace gif
} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/compactptrset.h b/tensorflow/core/lib/gtl/compactptrset.h
index 1d4d6cc8d2..d3d23b94aa 100644
--- a/tensorflow/core/lib/gtl/compactptrset.h
+++ b/tensorflow/core/lib/gtl/compactptrset.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
+#define TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
#include <type_traits>
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -205,4 +205,4 @@ class CompactPointerSet {
} // namespace gtl
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_
diff --git a/tensorflow/core/lib/gtl/flatmap.h b/tensorflow/core/lib/gtl/flatmap.h
index 6dd67ad2ea..889d2ddaa6 100644
--- a/tensorflow/core/lib/gtl/flatmap.h
+++ b/tensorflow/core/lib/gtl/flatmap.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
+#define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
#include <stddef.h>
#include <functional>
@@ -379,4 +379,4 @@ class FlatMap {
} // namespace gtl
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_
diff --git a/tensorflow/core/lib/gtl/flatrep.h b/tensorflow/core/lib/gtl/flatrep.h
index bb405b327a..0d7e7487fc 100644
--- a/tensorflow/core/lib/gtl/flatrep.h
+++ b/tensorflow/core/lib/gtl/flatrep.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
+#define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
#include <string.h>
#include <utility>
@@ -328,4 +328,4 @@ class FlatRep {
} // namespace gtl
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
diff --git a/tensorflow/core/lib/gtl/flatset.h b/tensorflow/core/lib/gtl/flatset.h
index 2b7f31ab22..f31e3abe41 100644
--- a/tensorflow/core/lib/gtl/flatset.h
+++ b/tensorflow/core/lib/gtl/flatset.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
+#define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
#include <stddef.h>
#include <functional>
@@ -278,4 +278,4 @@ class FlatSet {
} // namespace gtl
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h
index 2b824f35f8..924619f40f 100644
--- a/tensorflow/core/lib/io/buffered_inputstream.h
+++ b/tensorflow/core/lib/io/buffered_inputstream.h
@@ -104,4 +104,4 @@ class BufferedInputStream : public InputStreamInterface {
} // namespace io
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_
+#endif // TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_
diff --git a/tensorflow/core/lib/io/compression.h b/tensorflow/core/lib/io/compression.h
index 7a0c5c12a7..ef90c60a3a 100644
--- a/tensorflow/core/lib/io/compression.h
+++ b/tensorflow/core/lib/io/compression.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
+#define TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
namespace tensorflow {
namespace io {
@@ -27,4 +27,4 @@ extern const char kGzip[];
}
}
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
+#endif // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_
diff --git a/tensorflow/core/lib/io/inputstream_interface.h b/tensorflow/core/lib/io/inputstream_interface.h
index 096248693b..3083d20776 100644
--- a/tensorflow/core/lib/io/inputstream_interface.h
+++ b/tensorflow/core/lib/io/inputstream_interface.h
@@ -54,4 +54,4 @@ class InputStreamInterface {
} // namespace io
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_
+#endif // TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_
diff --git a/tensorflow/core/lib/io/random_inputstream.cc b/tensorflow/core/lib/io/random_inputstream.cc
index 8b8c1392a1..09336e79cd 100644
--- a/tensorflow/core/lib/io/random_inputstream.cc
+++ b/tensorflow/core/lib/io/random_inputstream.cc
@@ -57,6 +57,43 @@ Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read,
return Status::OK();
}
+// To limit memory usage, the default implementation of SkipNBytes() only reads
+// 8MB at a time.
+static constexpr int64 kMaxSkipSize = 8 * 1024 * 1024;
+
+Status RandomAccessInputStream::SkipNBytes(int64 bytes_to_skip) {
+ if (bytes_to_skip < 0) {
+ return errors::InvalidArgument("Can't skip a negative number of bytes");
+ }
+ std::unique_ptr<char[]> scratch(new char[kMaxSkipSize]);
+ // Try to read 1 bytes first, if we could complete the read then EOF is
+ // not reached yet and we could return.
+ if (bytes_to_skip > 0) {
+ StringPiece data;
+ Status s = file_->Read(pos_ + bytes_to_skip - 1, 1, &data, scratch.get());
+ if ((s.ok() || errors::IsOutOfRange(s)) && data.size() == 1) {
+ pos_ += bytes_to_skip;
+ return Status::OK();
+ }
+ }
+ // Read kDefaultSkipSize at a time till bytes_to_skip.
+ while (bytes_to_skip > 0) {
+ int64 bytes_to_read = std::min<int64>(kMaxSkipSize, bytes_to_skip);
+ StringPiece data;
+ Status s = file_->Read(pos_, bytes_to_read, &data, scratch.get());
+ if (s.ok() || errors::IsOutOfRange(s)) {
+ pos_ += data.size();
+ } else {
+ return s;
+ }
+ if (data.size() < bytes_to_read) {
+ return errors::OutOfRange("reached end of file");
+ }
+ bytes_to_skip -= bytes_to_read;
+ }
+ return Status::OK();
+}
+
int64 RandomAccessInputStream::Tell() const { return pos_; }
} // namespace io
diff --git a/tensorflow/core/lib/io/random_inputstream.h b/tensorflow/core/lib/io/random_inputstream.h
index 09ebe9ba49..bdbdbd71ff 100644
--- a/tensorflow/core/lib/io/random_inputstream.h
+++ b/tensorflow/core/lib/io/random_inputstream.h
@@ -34,6 +34,8 @@ class RandomAccessInputStream : public InputStreamInterface {
Status ReadNBytes(int64 bytes_to_read, string* result) override;
+ Status SkipNBytes(int64 bytes_to_skip) override;
+
int64 Tell() const override;
Status Seek(int64 position) {
diff --git a/tensorflow/core/lib/io/snappy/snappy_outputbuffer.h b/tensorflow/core/lib/io/snappy/snappy_outputbuffer.h
index 5d330a2c5a..5aea503846 100644
--- a/tensorflow/core/lib/io/snappy/snappy_outputbuffer.h
+++ b/tensorflow/core/lib/io/snappy/snappy_outputbuffer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_
+#define TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
@@ -117,4 +117,4 @@ class SnappyOutputBuffer {
} // namespace io
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_SNAPPY_OUTPUTBUFFER_H_
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h
index 5cad2e9457..3d86d89a99 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_
+#define TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_
#include <zlib.h>
@@ -143,4 +143,4 @@ class ZlibOutputBuffer : public WritableFile {
} // namespace io
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_COMPRESSED_OUTPUTBUFFER_H_
diff --git a/tensorflow/core/lib/monitoring/collected_metrics.h b/tensorflow/core/lib/monitoring/collected_metrics.h
index acdb0d86ed..e200981609 100644
--- a/tensorflow/core/lib/monitoring/collected_metrics.h
+++ b/tensorflow/core/lib/monitoring/collected_metrics.h
@@ -17,8 +17,8 @@ limitations under the License.
// These are to be used only by the CollectionRegistry and exporters which
// collect metrics using the CollectionRegistry.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_
#include <map>
#include <memory>
@@ -151,4 +151,4 @@ struct CollectedMetrics {
} // namespace monitoring
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 2c8e250c56..63cc0f550d 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
#include <map>
#include <memory>
@@ -356,4 +356,4 @@ MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get(
} // namespace monitoring
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
diff --git a/tensorflow/core/lib/monitoring/counter.h b/tensorflow/core/lib/monitoring/counter.h
index 7240348a9b..8ff810db41 100644
--- a/tensorflow/core/lib/monitoring/counter.h
+++ b/tensorflow/core/lib/monitoring/counter.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
// We replace this implementation with a null implementation for mobile
// platforms.
@@ -172,4 +172,4 @@ CounterCell* Counter<NumLabels>::GetCell(const Labels&... labels)
} // namespace tensorflow
#endif // IS_MOBILE_PLATFORM
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
diff --git a/tensorflow/core/lib/monitoring/gauge.h b/tensorflow/core/lib/monitoring/gauge.h
index ec978a9193..ee9a862f40 100644
--- a/tensorflow/core/lib/monitoring/gauge.h
+++ b/tensorflow/core/lib/monitoring/gauge.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_
// We replace this implementation with a null implementation for mobile
// platforms.
@@ -241,4 +241,4 @@ GaugeCell<ValueType>* Gauge<ValueType, NumLabels>::GetCell(
} // namespace tensorflow
#endif // IS_MOBILE_PLATFORM
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index f046842618..5ecadcc427 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
#include <array>
#include <vector>
@@ -139,4 +139,4 @@ class MetricDef : public AbstractMetricDef {
} // namespace monitoring
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
diff --git a/tensorflow/core/lib/monitoring/mobile_counter.h b/tensorflow/core/lib/monitoring/mobile_counter.h
index c30bfe026f..c297d843d2 100644
--- a/tensorflow/core/lib/monitoring/mobile_counter.h
+++ b/tensorflow/core/lib/monitoring/mobile_counter.h
@@ -15,8 +15,8 @@ limitations under the License.
// Null implementation of the Counter metric for mobile platforms.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -64,4 +64,4 @@ class Counter {
} // namespace monitoring
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_MOBILE_COUNTER_H_
diff --git a/tensorflow/core/lib/monitoring/mobile_gauge.h b/tensorflow/core/lib/monitoring/mobile_gauge.h
index ac13ad35c0..a03b41aef3 100644
--- a/tensorflow/core/lib/monitoring/mobile_gauge.h
+++ b/tensorflow/core/lib/monitoring/mobile_gauge.h
@@ -15,8 +15,8 @@ limitations under the License.
// Null implementation of the Gauge metric for mobile platforms.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -69,4 +69,4 @@ class Gauge {
} // namespace monitoring
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_
diff --git a/tensorflow/core/lib/monitoring/mobile_sampler.h b/tensorflow/core/lib/monitoring/mobile_sampler.h
index cf390e5c7f..77310dd619 100644
--- a/tensorflow/core/lib/monitoring/mobile_sampler.h
+++ b/tensorflow/core/lib/monitoring/mobile_sampler.h
@@ -15,8 +15,8 @@ limitations under the License.
// Null implementation of the Sampler metric for mobile platforms.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
#include <memory>
@@ -98,4 +98,4 @@ class Sampler {
} // namespace monitoring
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_MOBILE_SAMPLER_H_
diff --git a/tensorflow/core/lib/monitoring/sampler.h b/tensorflow/core/lib/monitoring/sampler.h
index c7a05428e2..a4f397f556 100644
--- a/tensorflow/core/lib/monitoring/sampler.h
+++ b/tensorflow/core/lib/monitoring/sampler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
+#ifndef TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
+#define TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
// We replace this implementation with a null implementation for mobile
// platforms.
@@ -215,4 +215,4 @@ SamplerCell* Sampler<NumLabels>::GetCell(const Labels&... labels)
} // namespace tensorflow
#endif // IS_MOBILE_PLATFORM
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
+#endif // TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
diff --git a/tensorflow/core/lib/strings/proto_text_util.h b/tensorflow/core/lib/strings/proto_text_util.h
index ed6d0af010..05dbda6e15 100644
--- a/tensorflow/core/lib/strings/proto_text_util.h
+++ b/tensorflow/core/lib/strings/proto_text_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
@@ -164,4 +164,4 @@ bool ProtoParseStringLiteralFromScanner(Scanner* scanner, string* value);
} // namespace strings
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_
diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc
new file mode 100644
index 0000000000..a64582acee
--- /dev/null
+++ b/tensorflow/core/ops/batch_ops.cc
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Batch")
+ .Input("in_tensors: T")
+ .Output("batched_tensors: T")
+ .Output("batch_index: int64")
+ .Output("id: int64")
+ .Attr("num_batch_threads: int")
+ .Attr("max_batch_size: int")
+ .Attr("batch_timeout_micros: int")
+ .Attr("allowed_batch_sizes: list(int) = []")
+ .Attr("grad_timeout_micros: int")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("batching_queue: string = ''")
+ .Attr("T: list(type)")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ std::vector<shape_inference::ShapeHandle> in_shapes;
+ TF_RETURN_IF_ERROR(c->input("in_tensors", &in_shapes));
+ std::vector<shape_inference::ShapeHandle> out_shapes(in_shapes.size());
+ for (int i = 0; i < in_shapes.size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(in_shapes[i], 0, c->UnknownDim(), &out_shapes[i]));
+ }
+ TF_RETURN_IF_ERROR(c->set_output("batched_tensors", out_shapes));
+ TF_RETURN_IF_ERROR(c->set_output("id", {c->Scalar()}));
+ TF_RETURN_IF_ERROR(c->set_output(
+ "batch_index",
+ {c->MakeShape({shape_inference::DimensionOrConstant(c->UnknownDim()),
+ shape_inference::DimensionOrConstant(3)})}));
+ return Status::OK();
+ });
+
+REGISTER_OP("Unbatch")
+ .Input("batched_tensor: T")
+ .Input("batch_index: int64")
+ .Input("id: int64")
+ .Output("unbatched_tensor: T")
+ .Attr("timeout_micros: int")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("T: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle out_shape;
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &out_shape));
+ c->set_output(0, out_shape);
+ return Status::OK();
+ });
+
+REGISTER_OP("UnbatchGrad")
+ .Input("original_input: T")
+ .Input("batch_index: int64")
+ .Input("grad: T")
+ .Input("id: int64")
+ .Output("batched_grad: T")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("T: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(2))));
+ return Status::OK();
+ });
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 08b685319e..65ab81931a 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -7904,6 +7904,76 @@ op {
}
}
op {
+ name: "Batch"
+ input_arg {
+ name: "in_tensors"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "batched_tensors"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "batch_index"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "id"
+ type: DT_INT64
+ }
+ attr {
+ name: "num_batch_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_batch_size"
+ type: "int"
+ }
+ attr {
+ name: "batch_timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "allowed_batch_sizes"
+ type: "list(int)"
+ default_value {
+ list {
+ }
+ }
+ }
+ attr {
+ name: "grad_timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "batching_queue"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BatchCholesky"
input_arg {
name: "input"
@@ -18753,6 +18823,57 @@ op {
is_stateful: true
}
op {
+ name: "FixedLengthRecordReader"
+ output_arg {
+ name: "reader_handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "header_bytes"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "record_bytes"
+ type: "int"
+ }
+ attr {
+ name: "footer_bytes"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "hop_bytes"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ deprecation {
+ version: 26
+ }
+ is_stateful: true
+}
+op {
name: "FixedLengthRecordReaderV2"
output_arg {
name: "reader_handle"
@@ -21315,6 +21436,32 @@ op {
is_stateful: true
}
op {
+ name: "IdentityReader"
+ output_arg {
+ name: "reader_handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ deprecation {
+ version: 26
+ }
+ is_stateful: true
+}
+op {
name: "IdentityReaderV2"
output_arg {
name: "reader_handle"
@@ -22533,6 +22680,30 @@ op {
is_stateful: true
}
op {
+ name: "IteratorGetNextSync"
+ input_arg {
+ name: "iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "IteratorSetStatsAggregator"
input_arg {
name: "iterator_handle"
@@ -38475,6 +38646,46 @@ op {
}
}
op {
+ name: "ResizeBilinear"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "size"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "resized_images"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT8
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_UINT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "align_corners"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "ResizeBilinearGrad"
input_arg {
name: "grads"
@@ -38508,6 +38719,40 @@ op {
}
}
op {
+ name: "ResizeBilinearGrad"
+ input_arg {
+ name: "grads"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "original_image"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "align_corners"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "ResizeNearestNeighbor"
input_arg {
name: "images"
@@ -61789,6 +62034,39 @@ op {
is_stateful: true
}
op {
+ name: "TFRecordReader"
+ output_arg {
+ name: "reader_handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "compression_type"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ deprecation {
+ version: 26
+ }
+ is_stateful: true
+}
+op {
name: "TFRecordReaderV2"
output_arg {
name: "reader_handle"
@@ -62190,6 +62468,16 @@ op {
}
}
op {
+ name: "TensorArrayCloseV2"
+ input_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ deprecation {
+ version: 26
+ }
+}
+op {
name: "TensorArrayCloseV3"
input_arg {
name: "handle"
@@ -62367,6 +62655,41 @@ op {
}
}
op {
+ name: "TensorArrayGatherV2"
+ input_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "value"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "element_shape"
+ type: "shape"
+ default_value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ deprecation {
+ version: 26
+ }
+}
+op {
name: "TensorArrayGatherV3"
input_arg {
name: "handle"
@@ -62444,6 +62767,29 @@ op {
is_stateful: true
}
op {
+ name: "TensorArrayGradV2"
+ input_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "grad_handle"
+ type: DT_STRING
+ }
+ attr {
+ name: "source"
+ type: "string"
+ }
+ deprecation {
+ version: 26
+ }
+ is_stateful: true
+}
+op {
name: "TensorArrayGradV3"
input_arg {
name: "handle"
@@ -62550,6 +62896,32 @@ op {
}
}
op {
+ name: "TensorArrayReadV2"
+ input_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "index"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "value"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ deprecation {
+ version: 26
+ }
+}
+op {
name: "TensorArrayReadV3"
input_arg {
name: "handle"
@@ -62632,6 +63004,36 @@ op {
}
}
op {
+ name: "TensorArrayScatterV2"
+ input_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "flow_out"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ deprecation {
+ version: 26
+ }
+}
+op {
name: "TensorArrayScatterV3"
input_arg {
name: "handle"
@@ -62694,6 +63096,24 @@ op {
}
}
op {
+ name: "TensorArraySizeV2"
+ input_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "size"
+ type: DT_INT32
+ }
+ deprecation {
+ version: 26
+ }
+}
+op {
name: "TensorArraySizeV3"
input_arg {
name: "handle"
@@ -62768,6 +63188,36 @@ op {
}
}
op {
+ name: "TensorArraySplitV2"
+ input_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "lengths"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "flow_out"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ deprecation {
+ version: 26
+ }
+}
+op {
name: "TensorArraySplitV3"
input_arg {
name: "handle"
@@ -62869,6 +63319,55 @@ op {
is_stateful: true
}
op {
+ name: "TensorArrayV2"
+ input_arg {
+ name: "size"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "element_shape"
+ type: "shape"
+ default_value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ attr {
+ name: "dynamic_size"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "clear_after_read"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "tensor_array_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ deprecation {
+ version: 26
+ }
+ is_stateful: true
+}
+op {
name: "TensorArrayV3"
input_arg {
name: "size"
@@ -63034,6 +63533,36 @@ op {
}
}
op {
+ name: "TensorArrayWriteV2"
+ input_arg {
+ name: "handle"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "index"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "flow_out"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ deprecation {
+ version: 26
+ }
+}
+op {
name: "TensorArrayWriteV3"
input_arg {
name: "handle"
@@ -63086,6 +63615,27 @@ op {
is_stateful: true
}
op {
+ name: "TensorListElementShape"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListFromTensor"
input_arg {
name: "tensor"
@@ -63115,6 +63665,25 @@ op {
}
}
op {
+ name: "TensorListGetItem"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "index"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "item"
+ type_attr: "element_dtype"
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListLength"
input_arg {
name: "input_handle"
@@ -63164,6 +63733,58 @@ op {
}
}
op {
+ name: "TensorListReserve"
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ input_arg {
+ name: "num_elements"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
+ name: "TensorListSetItem"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "index"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "item"
+ type_attr: "element_dtype"
+ }
+ output_arg {
+ name: "output_handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListStack"
input_arg {
name: "input_handle"
@@ -63320,6 +63941,39 @@ op {
is_stateful: true
}
op {
+ name: "TextLineReader"
+ output_arg {
+ name: "reader_handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "skip_header_lines"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ deprecation {
+ version: 26
+ }
+ is_stateful: true
+}
+op {
name: "TextLineReaderV2"
output_arg {
name: "reader_handle"
@@ -64141,6 +64795,88 @@ op {
is_stateful: true
}
op {
+ name: "Unbatch"
+ input_arg {
+ name: "batched_tensor"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "batch_index"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "unbatched_tensor"
+ type_attr: "T"
+ }
+ attr {
+ name: "timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
+ name: "UnbatchGrad"
+ input_arg {
+ name: "original_input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "batch_index"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "batched_grad"
+ type_attr: "T"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index cf949ed647..4f946fb3ca 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -171,29 +171,10 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
return Status::OK();
}
-Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
- ShapeHandle handle;
- DimensionHandle unused_handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
- for (int i = 1; i < c->num_inputs(); ++i) {
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
- }
- for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->Scalar());
- }
- return Status::OK();
-}
-
Status TwoElementOutput(InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
}
-
-Status ScalarOutput(InferenceContext* c) {
- c->set_output(0, c->Scalar());
- return Status::OK();
-}
} // namespace
REGISTER_OP("RandomShuffleQueue")
@@ -787,7 +768,6 @@ REGISTER_OP("TensorArray")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayV2")
.Input("size: int32")
.Attr("dtype: type")
@@ -802,7 +782,8 @@ REGISTER_OP("TensorArrayV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
c->set_output(0, c->Vector(2));
return Status::OK();
- });
+ })
+ .Deprecated(26, "Use TensorArrayV3");
REGISTER_OP("TensorArrayGrad")
.Input("handle: string")
.Input("flow_in: float")
@@ -811,7 +792,6 @@ REGISTER_OP("TensorArrayGrad")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayGradV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayGradV2")
.Input("handle: string")
.Input("flow_in: float")
@@ -825,7 +805,8 @@ REGISTER_OP("TensorArrayGradV2")
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
c->set_output(0, c->Vector(2));
return Status::OK();
- });
+ })
+ .Deprecated(26, "Use TensorArrayGradV3");
REGISTER_OP("TensorArrayWrite")
.Input("handle: Ref(string)")
.Input("index: int32")
@@ -835,7 +816,6 @@ REGISTER_OP("TensorArrayWrite")
.Attr("T: type")
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayWriteV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayWriteV2")
.Input("handle: string")
.Input("index: int32")
@@ -853,7 +833,8 @@ REGISTER_OP("TensorArrayWriteV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
return shape_inference::ScalarShape(c);
- });
+ })
+ .Deprecated(26, "Use TensorArrayWriteV3");
REGISTER_OP("TensorArrayRead")
.Input("handle: Ref(string)")
.Input("index: int32")
@@ -862,7 +843,6 @@ REGISTER_OP("TensorArrayRead")
.Attr("dtype: type")
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayReadV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayReadV2")
.Input("handle: string")
.Input("index: int32")
@@ -878,7 +858,8 @@ REGISTER_OP("TensorArrayReadV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
return shape_inference::UnknownShape(c);
- });
+ })
+ .Deprecated(26, "Use TensorArrayReadV3");
REGISTER_OP("TensorArrayPack")
.Input("handle: Ref(string)")
.Input("flow_in: float")
@@ -904,7 +885,6 @@ REGISTER_OP("TensorArrayGather")
.Attr("element_shape: shape = { unknown_rank: true }")
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArrayGatherV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayGatherV2")
.Input("handle: string")
.Input("indices: int32")
@@ -920,7 +900,8 @@ REGISTER_OP("TensorArrayGatherV2")
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
return shape_inference::UnknownShape(c);
- });
+ })
+ .Deprecated(26, "Use TensorArrayGatherV3");
REGISTER_OP("TensorArrayScatter")
.Input("handle: Ref(string)")
.Input("indices: int32")
@@ -930,7 +911,6 @@ REGISTER_OP("TensorArrayScatter")
.Attr("T: type")
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(19, "Use TensorArrayGradV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayScatterV2")
.Input("handle: string")
.Input("indices: int32")
@@ -946,7 +926,8 @@ REGISTER_OP("TensorArrayScatterV2")
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
return shape_inference::ScalarShape(c);
- });
+ })
+ .Deprecated(26, "Use TensorArrayScatterV3");
REGISTER_OP("TensorArrayConcat")
.Input("handle: Ref(string)")
.Input("flow_in: float")
@@ -983,7 +964,6 @@ REGISTER_OP("TensorArraySplit")
.Attr("T: type")
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArraySplitV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArraySplitV2")
.Input("handle: string")
.Input("value: T")
@@ -1000,14 +980,14 @@ REGISTER_OP("TensorArraySplitV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
return shape_inference::ScalarShape(c);
- });
+ })
+ .Deprecated(26, "Use TensorArraySplitV3");
REGISTER_OP("TensorArraySize")
.Input("handle: Ref(string)")
.Input("flow_in: float")
.Output("size: int32")
.SetShapeFn(shape_inference::UnknownShape)
.Deprecated(16, "Use TensorArraySizeV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArraySizeV2")
.Input("handle: string")
.Input("flow_in: float")
@@ -1018,12 +998,12 @@ REGISTER_OP("TensorArraySizeV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
return shape_inference::ScalarShape(c);
- });
+ })
+ .Deprecated(26, "Use TensorArraySizeV3");
REGISTER_OP("TensorArrayClose")
.Input("handle: Ref(string)")
.SetShapeFn([](InferenceContext* c) { return Status::OK(); })
.Deprecated(16, "Use TensorArrayCloseV3");
-// TODO(cwhipkey): mark this deprecated in favor of V3.
REGISTER_OP("TensorArrayCloseV2")
.Input("handle: string")
.SetShapeFn([](InferenceContext* c) {
@@ -1032,7 +1012,8 @@ REGISTER_OP("TensorArrayCloseV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
return Status::OK();
- });
+ })
+ .Deprecated(26, "Use TensorArrayCloseV3");
// --------------------------------------------------------------------------
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index b86816bb54..2cae814eab 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -409,53 +409,49 @@ REGISTER_OP("OneShotIterator")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
+namespace {
+
+Status IteratorGetNextShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("IteratorGetNext")
.Input("iterator: resource")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
- });
+ .SetShapeFn(IteratorGetNextShapeFn);
+
+REGISTER_OP("IteratorGetNextSync")
+ .Input("iterator: resource")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(IteratorGetNextShapeFn);
REGISTER_OP("DatasetToSingleElement")
.Input("dataset: variant")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
- });
+ .SetShapeFn(IteratorGetNextShapeFn);
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 31cc662d21..ef2ac267cc 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -25,42 +25,6 @@ using shape_inference::ShapeHandle;
namespace {
-const char kDecodeJpegCommonDocStr[] = R"doc(
-The attr `channels` indicates the desired number of color channels for the
-decoded image.
-
-Accepted values are:
-
-* 0: Use the number of channels in the JPEG-encoded image.
-* 1: output a grayscale image.
-* 3: output an RGB image.
-
-If needed, the JPEG-encoded image is transformed to match the requested number
-of color channels.
-
-The attr `ratio` allows downscaling the image by an integer factor during
-decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
-downscaling the image later.
-
-)doc";
-
-const char kDecodeJpegCommonParamsDocStr[] = R"doc(
-channels: Number of color channels for the decoded image.
-ratio: Downscaling ratio.
-fancy_upscaling: If true use a slower but nicer upscaling of the
- chroma planes (yuv420/422 only).
-try_recover_truncated: If true try to recover an image from truncated input.
-acceptable_fraction: The minimum required fraction of lines before a truncated
- input is accepted.
-dct_method: string specifying a hint about the algorithm used for
- decompression. Defaults to "" which maps to a system-specific
- default. Currently valid values are ["INTEGER_FAST",
- "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
- jpeg library changes to a version that does not have that specific
- option.)
-image: 3-D with shape `[height, width, channels]`..
-)doc";
-
// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
// height and width come from the size_tensor.
Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
@@ -181,7 +145,9 @@ REGISTER_OP("ResizeBilinear")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: float")
- .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
+ .Attr(
+ "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
+ "float, double}")
.Attr("align_corners: bool = false")
.SetShapeFn(ResizeShapeFn);
@@ -212,7 +178,7 @@ REGISTER_OP("ResizeBilinearGrad")
.Input("grads: float")
.Input("original_image: T")
.Output("output: T")
- .Attr("T: {float, half, double}")
+ .Attr("T: {float, bfloat16, half, double}")
.Attr("align_corners: bool = false")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(1));
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 21f0d02ff2..7db4d0c4b6 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -272,14 +272,14 @@ REGISTER_OP("WholeFileReaderV2")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
-// TODO(cwhipkey): mark this deprecated in favor of V2.
REGISTER_OP("TextLineReader")
.Output("reader_handle: Ref(string)")
.Attr("skip_header_lines: int = 0")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
- .SetShapeFn(TwoElementOutput);
+ .SetShapeFn(TwoElementOutput)
+ .Deprecated(26, "Use TextLineReaderV2");
REGISTER_OP("TextLineReaderV2")
.Output("reader_handle: resource")
@@ -289,7 +289,6 @@ REGISTER_OP("TextLineReaderV2")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
-// TODO(cwhipkey): mark this deprecated in favor of V2.
REGISTER_OP("FixedLengthRecordReader")
.Output("reader_handle: Ref(string)")
.Attr("header_bytes: int = 0")
@@ -299,7 +298,8 @@ REGISTER_OP("FixedLengthRecordReader")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
- .SetShapeFn(TwoElementOutput);
+ .SetShapeFn(TwoElementOutput)
+ .Deprecated(26, "Use FixedLengthRecordReaderV2");
REGISTER_OP("FixedLengthRecordReaderV2")
.Output("reader_handle: resource")
@@ -313,14 +313,14 @@ REGISTER_OP("FixedLengthRecordReaderV2")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
-// TODO(cwhipkey): mark this deprecated in favor of V2.
REGISTER_OP("TFRecordReader")
.Output("reader_handle: Ref(string)")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("compression_type: string = ''")
.SetIsStateful()
- .SetShapeFn(TwoElementOutput);
+ .SetShapeFn(TwoElementOutput)
+ .Deprecated(26, "Use TFRecordReaderV2");
REGISTER_OP("TFRecordReaderV2")
.Output("reader_handle: resource")
@@ -337,13 +337,13 @@ REGISTER_OP("LMDBReader")
.SetIsStateful()
.SetShapeFn(TwoElementOutput);
-// TODO(cwhipkey): mark this deprecated in favor of V2.
REGISTER_OP("IdentityReader")
.Output("reader_handle: Ref(string)")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
- .SetShapeFn(TwoElementOutput);
+ .SetShapeFn(TwoElementOutput)
+ .Deprecated(26, "Use IdentityReaderV2");
REGISTER_OP("IdentityReaderV2")
.Output("reader_handle: resource")
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index db53485772..fa40f41bb9 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -176,5 +176,81 @@ REGISTER_OP("TensorListFromTensor")
return Status::OK();
});
+REGISTER_OP("TensorListElementShape")
+ .Input("input_handle: variant")
+ .Output("element_shape: shape_type")
+ .Attr("shape_type: {int32, int64}")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data == nullptr) {
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+ }
+ c->set_output(0, c->Vector(c->Rank((*handle_data)[0].shape)));
+ return Status::OK();
+ });
+
+REGISTER_OP("TensorListReserve")
+ .Input("element_shape: shape_type")
+ .Input("num_elements: int32")
+ .Output("handle: variant")
+ .Attr("element_dtype: type")
+ .Attr("shape_type: {int32, int64}")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ c->set_output_handle_shapes_and_types(
+ 0, std::vector<shape_inference::ShapeAndType>{{s, t}});
+ return Status::OK();
+ });
+
+REGISTER_OP("TensorListGetItem")
+ .Input("input_handle: variant")
+ .Input("index: int32")
+ .Output("item: element_dtype")
+ .Attr("element_dtype: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ shape_inference::ShapeHandle element_shape = c->UnknownShape();
+ if (handle_data != nullptr) {
+ const shape_inference::ShapeAndType& list_shape_type =
+ (*handle_data)[0];
+ element_shape = list_shape_type.shape;
+ if (list_shape_type.dtype != t) {
+ return errors::InvalidArgument("Expected list with element dtype ",
+ DataTypeString(t),
+ " but got list with element dtype ",
+ DataTypeString(list_shape_type.dtype));
+ }
+ }
+ c->set_output(0, element_shape);
+ return Status::OK();
+ });
+
+REGISTER_OP("TensorListSetItem")
+ .Input("input_handle: variant")
+ .Input("index: int32")
+ .Input("item: element_dtype")
+ .Output("output_handle: variant")
+ .Attr("element_dtype: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data == nullptr) {
+ c->set_output_handle_shapes_and_types(0, {{c->UnknownShape(), t}});
+ return Status::OK();
+ }
+ const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
+ shape_inference::ShapeHandle s = c->input(2);
+ TF_RETURN_IF_ERROR(c->Merge(s, list_shape_type.shape, &s));
+ c->set_output_handle_shapes_and_types(0, *handle_data);
+ return Status::OK();
+ });
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 82a895a98b..b57206c9c4 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -2738,6 +2738,76 @@ op {
}
}
op {
+ name: "Batch"
+ input_arg {
+ name: "in_tensors"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "batched_tensors"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "batch_index"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "id"
+ type: DT_INT64
+ }
+ attr {
+ name: "num_batch_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_batch_size"
+ type: "int"
+ }
+ attr {
+ name: "batch_timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "allowed_batch_sizes"
+ type: "list(int)"
+ default_value {
+ list {
+ }
+ }
+ }
+ attr {
+ name: "grad_timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "batching_queue"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BatchCholesky"
input_arg {
name: "input"
@@ -8470,6 +8540,10 @@ op {
s: ""
}
}
+ deprecation {
+ version: 26
+ explanation: "Use FixedLengthRecordReaderV2"
+ }
is_stateful: true
}
op {
@@ -10067,6 +10141,10 @@ op {
s: ""
}
}
+ deprecation {
+ version: 26
+ explanation: "Use IdentityReaderV2"
+ }
is_stateful: true
}
op {
@@ -10789,6 +10867,30 @@ op {
is_stateful: true
}
op {
+ name: "IteratorGetNextSync"
+ input_arg {
+ name: "iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "IteratorSetStatsAggregator"
input_arg {
name: "iterator_handle"
@@ -19629,6 +19731,7 @@ op {
type: DT_UINT16
type: DT_INT32
type: DT_INT64
+ type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
@@ -19663,6 +19766,7 @@ op {
allowed_values {
list {
type: DT_FLOAT
+ type: DT_BFLOAT16
type: DT_HALF
type: DT_DOUBLE
}
@@ -28548,6 +28652,10 @@ op {
s: ""
}
}
+ deprecation {
+ version: 26
+ explanation: "Use TFRecordReaderV2"
+ }
is_stateful: true
}
op {
@@ -28818,6 +28926,10 @@ op {
name: "handle"
type: DT_STRING
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArrayCloseV3"
+ }
}
op {
name: "TensorArrayCloseV3"
@@ -28997,6 +29109,10 @@ op {
}
}
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArrayGatherV3"
+ }
}
op {
name: "TensorArrayGatherV3"
@@ -29074,6 +29190,10 @@ op {
name: "source"
type: "string"
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArrayGradV3"
+ }
is_stateful: true
}
op {
@@ -29183,6 +29303,10 @@ op {
name: "dtype"
type: "type"
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArrayReadV3"
+ }
}
op {
name: "TensorArrayReadV3"
@@ -29266,6 +29390,10 @@ op {
name: "T"
type: "type"
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArrayScatterV3"
+ }
}
op {
name: "TensorArrayScatterV3"
@@ -29329,6 +29457,10 @@ op {
name: "size"
type: DT_INT32
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArraySizeV3"
+ }
}
op {
name: "TensorArraySizeV3"
@@ -29404,6 +29536,10 @@ op {
name: "T"
type: "type"
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArraySplitV3"
+ }
}
op {
name: "TensorArraySplitV3"
@@ -29505,6 +29641,10 @@ op {
s: ""
}
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArrayV3"
+ }
is_stateful: true
}
op {
@@ -29622,6 +29762,10 @@ op {
name: "T"
type: "type"
}
+ deprecation {
+ version: 26
+ explanation: "Use TensorArrayWriteV3"
+ }
}
op {
name: "TensorArrayWriteV3"
@@ -29676,6 +29820,27 @@ op {
is_stateful: true
}
op {
+ name: "TensorListElementShape"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListFromTensor"
input_arg {
name: "tensor"
@@ -29705,6 +29870,25 @@ op {
}
}
op {
+ name: "TensorListGetItem"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "index"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "item"
+ type_attr: "element_dtype"
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListLength"
input_arg {
name: "input_handle"
@@ -29754,6 +29938,58 @@ op {
}
}
op {
+ name: "TensorListReserve"
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ input_arg {
+ name: "num_elements"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
+ name: "TensorListSetItem"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "index"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "item"
+ type_attr: "element_dtype"
+ }
+ output_arg {
+ name: "output_handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListStack"
input_arg {
name: "input_handle"
@@ -29907,6 +30143,10 @@ op {
s: ""
}
}
+ deprecation {
+ version: 26
+ explanation: "Use TextLineReaderV2"
+ }
is_stateful: true
}
op {
@@ -30290,6 +30530,88 @@ op {
is_stateful: true
}
op {
+ name: "Unbatch"
+ input_arg {
+ name: "batched_tensor"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "batch_index"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "unbatched_tensor"
+ type_attr: "T"
+ }
+ attr {
+ name: "timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
+ name: "UnbatchGrad"
+ input_arg {
+ name: "original_input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "batch_index"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "batched_grad"
+ type_attr: "T"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index e8d03877c9..6ce9595fb6 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -22,48 +22,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-const char kAddSignCommonDocStr[] = R"doc(
-Update '*var' according to the AddSign update.
-
-m_t <- beta1 * m_{t-1} + (1 - beta1) * g
-update <- (alpha + sign_decay * sign(g) *sign(m)) * g
-variable <- variable - lr_t * update
-
-var: Should be from a Variable().
-m: Should be from a Variable().
-lr: Scaling factor. Must be a scalar.
-sign_decay: Must be a scalar.
-alpha: Must be a scalar.
-beta: Must be a scalar.
-grad: The gradient.
-)doc";
-
-const char kPowerSignCommonDocStr[] = R"doc(
-Update '*var' according to the AddSign update.
-
-m_t <- beta1 * m_{t-1} + (1 - beta1) * g
-update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g
-variable <- variable - lr_t * update
-
-var: Should be from a Variable().
-m: Should be from a Variable().
-lr: Scaling factor. Must be a scalar.
-logbase: Must be a scalar.
-sign_decay: Must be a scalar.
-beta: Must be a scalar.
-grad: The gradient.
-)doc";
-
-const char kOutDocStr[] = R"doc(
-out: Same as "var".
-)doc";
-
-const char kLockDocStr[] = R"doc(
-use_locking: If `True`, updating of the var and m tensors is
- protected by a lock; otherwise the behavior is undefined, but may exhibit less
- contention.
-)doc";
-
static ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) {
auto* handle_data = c->input_handle_shapes_and_types(input);
if (handle_data != nullptr && !handle_data->empty() &&
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 6b6be757f6..07aecf8483 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -102,7 +102,7 @@ cc_library(
":http_request",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_internal",
- "@curl//:curl",
+ "@curl",
],
)
@@ -119,7 +119,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
- "@curl//:curl",
+ "@curl",
],
)
diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.h b/tensorflow/core/platform/cloud/gcs_dns_cache.h
index dd95c18f35..40f16f1044 100644
--- a/tensorflow/core/platform/cloud/gcs_dns_cache.h
+++ b/tensorflow/core/platform/cloud/gcs_dns_cache.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
-#define THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
+#ifndef TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
+#define TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
#include <random>
@@ -74,4 +74,4 @@ class GcsDnsCache {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
+#endif // TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 4b30291076..520720372d 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -117,6 +117,9 @@ constexpr char kReadRequestTimeout[] = "GCS_READ_REQUEST_TIMEOUT_SECS";
// The environment variable to configure the overall request timeout for
// upload requests.
constexpr char kWriteRequestTimeout[] = "GCS_WRITE_REQUEST_TIMEOUT_SECS";
+// The environment variable to configure an additional header to send with
+// all requests to GCS (format HEADERNAME:HEADERCONTENT)
+constexpr char kAdditionalRequestHeader[] = "GCS_ADDITIONAL_REQUEST_HEADER";
// TODO: DO NOT use a hardcoded path
Status GetTmpFilename(string* filename) {
@@ -607,6 +610,11 @@ bool GetEnvVar(const char* varname, bool (*convert)(StringPiece, T*),
return convert(env_value, value);
}
+bool StringPieceIdentity(StringPiece str, StringPiece* value) {
+ *value = str;
+ return true;
+}
+
} // namespace
GcsFileSystem::GcsFileSystem()
@@ -668,6 +676,36 @@ GcsFileSystem::GcsFileSystem()
VLOG(1) << "GCS DNS cache is disabled, because " << kResolveCacheSecs
<< " = 0 (or is not set)";
}
+
+ // Get the additional header
+ StringPiece add_header_contents;
+ if (GetEnvVar(kAdditionalRequestHeader, StringPieceIdentity,
+ &add_header_contents)) {
+ size_t split = add_header_contents.find(':', 0);
+
+ if (split != StringPiece::npos) {
+ StringPiece header_name = add_header_contents.substr(0, split);
+ StringPiece header_value = add_header_contents.substr(split + 1);
+
+ if (!header_name.empty() && !header_value.empty()) {
+ additional_header_.reset(new std::pair<const string, const string>(
+ header_name.ToString(), header_value.ToString()));
+
+ VLOG(1) << "GCS additional header ENABLED. "
+ << "Name: " << additional_header_->first << ", "
+ << "Value: " << additional_header_->second;
+ } else {
+ LOG(ERROR) << "GCS additional header DISABLED. Invalid contents: "
+ << add_header_contents;
+ }
+ } else {
+ LOG(ERROR) << "GCS additional header DISABLED. Invalid contents: "
+ << add_header_contents;
+ }
+ } else {
+ VLOG(1) << "GCS additional header DISABLED. No environment variable set.";
+ }
+
// Apply the overrides for request timeouts
uint32 timeout_value;
if (GetEnvVar(kRequestConnectionTimeout, strings::safe_strtou32,
@@ -696,7 +734,8 @@ GcsFileSystem::GcsFileSystem(
uint64 stat_cache_max_age, size_t stat_cache_max_entries,
uint64 matching_paths_cache_max_age,
size_t matching_paths_cache_max_entries, int64 initial_retry_delay_usec,
- TimeoutConfig timeouts)
+ TimeoutConfig timeouts,
+ std::pair<const string, const string>* additional_header)
: auth_provider_(std::move(auth_provider)),
http_request_factory_(std::move(http_request_factory)),
file_block_cache_(
@@ -705,7 +744,8 @@ GcsFileSystem::GcsFileSystem(
matching_paths_cache_(new MatchingPathsCache(
matching_paths_cache_max_age, matching_paths_cache_max_entries)),
timeouts_(timeouts),
- initial_retry_delay_usec_(initial_retry_delay_usec) {}
+ initial_retry_delay_usec_(initial_retry_delay_usec),
+ additional_header_(additional_header) {}
Status GcsFileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
@@ -1397,6 +1437,11 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) {
new_request->AddAuthBearerHeader(auth_token);
+ if (additional_header_) {
+ new_request->AddHeader(additional_header_->first,
+ additional_header_->second);
+ }
+
*request = std::move(new_request);
return Status::OK();
}
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index adde161a93..2eae39608e 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -17,7 +17,9 @@ limitations under the License.
#define TENSORFLOW_CORE_PLATFORM_GCS_FILE_SYSTEM_H_
#include <string>
+#include <utility>
#include <vector>
+
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cloud/auth_provider.h"
#include "tensorflow/core/platform/cloud/expiring_lru_cache.h"
@@ -44,7 +46,8 @@ class GcsFileSystem : public FileSystem {
uint64 stat_cache_max_age, size_t stat_cache_max_entries,
uint64 matching_paths_cache_max_age,
size_t matching_paths_cache_max_entries,
- int64 initial_retry_delay_usec, TimeoutConfig timeouts);
+ int64 initial_retry_delay_usec, TimeoutConfig timeouts,
+ std::pair<const string, const string>* additional_header);
Status NewRandomAccessFile(
const string& filename,
@@ -92,6 +95,12 @@ class GcsFileSystem : public FileSystem {
size_t max_bytes() const { return file_block_cache_->max_bytes(); }
uint64 max_staleness() const { return file_block_cache_->max_staleness(); }
TimeoutConfig timeouts() const { return timeouts_; }
+ string additional_header_name() const {
+ return additional_header_ ? additional_header_->first : "";
+ }
+ string additional_header_value() const {
+ return additional_header_ ? additional_header_->second : "";
+ }
uint64 stat_cache_max_age() const { return stat_cache_->max_age(); }
size_t stat_cache_max_entries() const { return stat_cache_->max_entries(); }
@@ -197,6 +206,9 @@ class GcsFileSystem : public FileSystem {
/// The initial delay for exponential backoffs when retrying failed calls.
const int64 initial_retry_delay_usec_ = 1000000L;
+ // Additional header material to be transmitted with all GCS requests
+ std::unique_ptr<std::pair<const string, const string>> additional_header_;
+
TF_DISALLOW_COPY_AND_ASSIGN(GcsFileSystem);
};
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 772aec5273..d452074ce3 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -53,7 +53,8 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -93,7 +94,8 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_differentN) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -137,15 +139,15 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) {
"Range: 18-26\n"
"Timeouts: 5 1 20\n",
"")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 9 /* block size */, 18 /* max bytes */,
- 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
@@ -211,15 +213,15 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) {
"Range: 0-8\n"
"Timeouts: 5 1 20\n",
"012345678")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 9 /* block size */, 18 /* max bytes */,
- 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
@@ -252,15 +254,15 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) {
"Range: 8-15\n"
"Timeouts: 5 1 20\n",
"89abcdef")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 8 /* block size */, 16 /* max bytes */,
- 3600 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 8 /* block size */, 16 /* max bytes */, 3600 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
// There should only be two HTTP requests issued to GCS even though we iterate
@@ -294,15 +296,15 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) {
TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* read ahead bytes */, 0 /* max bytes */,
- 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -344,7 +346,8 @@ TEST(GcsFileSystemTest, NewWritableFile) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
// Read from the file first, to fill the block cache.
std::unique_ptr<RandomAccessFile> rfile;
@@ -418,7 +421,8 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -465,15 +469,15 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) {
"Range: 0-7\n"
"Timeouts: 5 1 20\n",
"01234567")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 8 /* block size */, 8 /* max bytes */,
- 3600 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 8 /* block size */, 8 /* max bytes */, 3600 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
// Pull the file's first block into the cache. This will trigger the first
// HTTP request to GCS.
std::unique_ptr<RandomAccessFile> rfile;
@@ -557,7 +561,8 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 2 /* initial retry delay */, kTestTimeoutConfig);
+ 2 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -612,7 +617,8 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -641,7 +647,8 @@ TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -676,15 +683,15 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
"Range: 0-31\n"
"Timeouts: 5 1 20\n",
"01234567")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 32 /* block size */, 32 /* max bytes */,
- 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 32 /* block size */, 32 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
// Create an appendable file. This should read the file from GCS, and pull its
// contents into the block cache.
@@ -717,7 +724,8 @@ TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -748,7 +756,8 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
@@ -767,7 +776,8 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -789,7 +799,8 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/file1.txt"));
}
@@ -817,7 +828,8 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder"));
}
@@ -841,7 +853,8 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket1"));
TF_EXPECT_OK(fs.FileExists("gs://bucket1/"));
@@ -869,7 +882,8 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::NOT_FOUND,
fs.FileExists("gs://bucket/path/file1.txt").code());
@@ -894,7 +908,8 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.FileExists("gs://bucket2/").code());
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -924,15 +939,15 @@ TEST(GcsFileSystemTest, FileExists_StatCache) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subfolder/\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 3600 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
// The stat cache will ensure that repeated lookups don't trigger additional
// HTTP requests.
@@ -957,7 +972,8 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -983,7 +999,8 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1010,7 +1027,8 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1036,7 +1054,8 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1059,7 +1078,8 @@ TEST(GcsFileSystemTest, GetChildren_Root) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket-a-b-c", &children));
@@ -1082,7 +1102,8 @@ TEST(GcsFileSystemTest, GetChildren_Empty) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1121,7 +1142,8 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1146,7 +1168,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(
@@ -1172,7 +1195,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/*/*", &result));
@@ -1199,7 +1223,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file2.txt", &result));
@@ -1223,7 +1248,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*", &result));
@@ -1247,7 +1273,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file3.txt", &result));
@@ -1263,7 +1290,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::vector<string> result;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1295,7 +1323,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
3600 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
// Repeated calls to fs.GetMatchingPaths on these patterns should not lead to
// any additional HTTP requests to GCS.
@@ -1336,7 +1365,8 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
3600 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
// This loop should trigger the first HTTP request to GCS.
for (int i = 0; i < 10; i++) {
@@ -1377,15 +1407,15 @@ TEST(GcsFileSystemTest, DeleteFile) {
"Range: 0-15\n"
"Timeouts: 5 1 20\n",
"76543210")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 16 /* block size */, 16 /* max bytes */,
- 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 16 /* block size */, 16 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
// Do an initial read of the file to load its contents into the block cache.
char scratch[100];
@@ -1411,7 +1441,8 @@ TEST(GcsFileSystemTest, DeleteFile_NoObjectName) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.DeleteFile("gs://bucket/").code());
@@ -1431,7 +1462,8 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1458,7 +1490,8 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1476,7 +1509,8 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket"));
}
@@ -1496,7 +1530,8 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.DeleteDir("gs://bucket/path/").code());
@@ -1517,7 +1552,8 @@ TEST(GcsFileSystemTest, GetFileSize) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size));
@@ -1533,7 +1569,8 @@ TEST(GcsFileSystemTest, GetFileSize_NoObjectName) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
uint64 size;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1617,7 +1654,8 @@ TEST(GcsFileSystemTest, RenameFile_Folder) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.RenameFile("gs://bucket/path1", "gs://bucket/path2/"));
}
@@ -1680,15 +1718,15 @@ TEST(GcsFileSystemTest, RenameFile_Object) {
"Range: 0-15\n"
"Timeouts: 5 1 20\n",
"fedcba98")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 16 /* block size */, 64 /* max bytes */,
- 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 16 /* block size */, 64 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
// Do an initial read of the source and destination files to load their
// contents into the block cache.
char scratch[100];
@@ -1761,7 +1799,8 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(
fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt"));
@@ -1801,7 +1840,8 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
EXPECT_EQ(
errors::Code::UNIMPLEMENTED,
@@ -1824,7 +1864,8 @@ TEST(GcsFileSystemTest, Stat_Object) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
@@ -1856,7 +1897,8 @@ TEST(GcsFileSystemTest, Stat_Folder) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder", &stat));
@@ -1887,7 +1929,8 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/path", &stat).code());
@@ -1906,7 +1949,8 @@ TEST(GcsFileSystemTest, Stat_Bucket) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/", &stat));
@@ -1928,7 +1972,8 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/", &stat).code());
@@ -1957,15 +2002,15 @@ TEST(GcsFileSystemTest, Stat_Cache) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"subfolder/\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 3600 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
// Repeated calls to fs.Stat on these paths should not lead to any additional
// HTTP requests to GCS.
@@ -1998,15 +2043,15 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */,
- 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 3600 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
// There should be a single HTTP request to GCS for fs.Stat in this loop.
for (int i = 0; i < 10; i++) {
FileStatistics stat;
@@ -2048,7 +2093,8 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2077,7 +2123,8 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2106,7 +2153,8 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder/"));
@@ -2131,7 +2179,8 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/"));
@@ -2150,7 +2199,8 @@ TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/").code());
}
@@ -2190,7 +2240,8 @@ TEST(GcsFileSystemTest, CreateDir_Folder) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath/"));
@@ -2215,7 +2266,8 @@ TEST(GcsFileSystemTest, CreateDir_Bucket) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket"));
@@ -2285,7 +2337,8 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -2376,7 +2429,8 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -2409,7 +2463,8 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig);
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
EXPECT_EQ(error::Code::NOT_FOUND,
@@ -2420,6 +2475,64 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) {
EXPECT_EQ(1, undeleted_dirs);
}
+TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) {
+ GcsFileSystem fs1;
+ EXPECT_EQ("", fs1.additional_header_name());
+ EXPECT_EQ("", fs1.additional_header_value());
+
+ setenv("GCS_ADDITIONAL_REQUEST_HEADER",
+ "X-Add-Header:My Additional Header Value", 1);
+ GcsFileSystem fs2;
+ EXPECT_EQ("X-Add-Header", fs2.additional_header_name());
+ EXPECT_EQ("My Additional Header Value", fs2.additional_header_value());
+
+ setenv("GCS_ADDITIONAL_REQUEST_HEADER", "Someinvalidheadervalue", 1);
+ GcsFileSystem fs3;
+ EXPECT_EQ("", fs3.additional_header_name());
+ EXPECT_EQ("", fs3.additional_header_value());
+
+ setenv("GCS_ADDITIONAL_REQUEST_HEADER", ":thisisinvalid", 1);
+ GcsFileSystem fs4;
+ EXPECT_EQ("", fs4.additional_header_name());
+ EXPECT_EQ("", fs4.additional_header_value());
+
+ setenv("GCS_ADDITIONAL_REQUEST_HEADER", "soisthis:", 1);
+ GcsFileSystem fs5;
+ EXPECT_EQ("", fs5.additional_header_name());
+ EXPECT_EQ("", fs5.additional_header_value());
+
+ setenv("GCS_ADDITIONAL_REQUEST_HEADER", "a:b", 1);
+ GcsFileSystem fs6;
+ EXPECT_EQ("a", fs6.additional_header_name());
+ EXPECT_EQ("b", fs6.additional_header_value());
+
+ auto* add_header = new std::pair<const string, const string>(
+ "mynewheader", "newheadercontents");
+
+ std::vector<HttpRequest*> requests(
+ {// IsDirectory is checking whether there are children objects.
+ new FakeHttpRequest("Uri: https://www.googleapis.com/fake\n"
+ "Auth Token: fake_token\n"
+ "Header mynewheader: newheadercontents\n"
+ "Header Hello: world\n",
+ "{}")});
+ GcsFileSystem fs7(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, add_header /* gcs additional header */);
+
+ std::unique_ptr<HttpRequest> request;
+ TF_EXPECT_OK(fs7.CreateHttpRequest(&request));
+ request->SetUri("https://www.googleapis.com/fake");
+ request->AddHeader("Hello", "world");
+ TF_EXPECT_OK(request->Send());
+}
+
TEST(GcsFileSystemTest, OverrideCacheParameters) {
// Verify defaults are propagated correctly.
GcsFileSystem fs1;
@@ -2485,7 +2598,8 @@ TEST(GcsFileSystemTest, CreateHttpRequest) {
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig);
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
std::unique_ptr<HttpRequest> request;
TF_EXPECT_OK(fs.CreateHttpRequest(&request));
diff --git a/tensorflow/core/platform/cloud/oauth_client.h b/tensorflow/core/platform/cloud/oauth_client.h
index 1614c7b315..519d69acf9 100644
--- a/tensorflow/core/platform/cloud/oauth_client.h
+++ b/tensorflow/core/platform/cloud/oauth_client.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
#include <memory>
#include "include/json/json.h"
@@ -59,4 +59,4 @@ class OAuthClient {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/cloud/retrying_utils.h
index 99ab216e97..546b8d1c4a 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.h
+++ b/tensorflow/core/platform/cloud/retrying_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_
#include <functional>
#include "tensorflow/core/lib/core/status.h"
@@ -47,4 +47,4 @@ class RetryingUtils {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_UTILS_H_
diff --git a/tensorflow/core/platform/cloud/time_util.h b/tensorflow/core/platform/cloud/time_util.h
index b1bb7f1119..d6d4bc499f 100644
--- a/tensorflow/core/platform/cloud/time_util.h
+++ b/tensorflow/core/platform/cloud/time_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_
#include "tensorflow/core/lib/core/status.h"
@@ -26,4 +26,4 @@ Status ParseRfc3339Time(const string& time, int64* mtime_nsec);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_
diff --git a/tensorflow/core/platform/cuda_libdevice_path.h b/tensorflow/core/platform/cuda_libdevice_path.h
index 601d0db6d4..6ef565ecd3 100644
--- a/tensorflow/core/platform/cuda_libdevice_path.h
+++ b/tensorflow/core/platform/cuda_libdevice_path.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_
+#define TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_
#include "tensorflow/core/platform/types.h"
@@ -29,4 +29,4 @@ string LibdeviceRoot();
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CUDA_LIBDEVICE_PATH_H_
diff --git a/tensorflow/core/platform/cupti_wrapper.h b/tensorflow/core/platform/cupti_wrapper.h
index c909dcd35b..9a17ab60c0 100644
--- a/tensorflow/core/platform/cupti_wrapper.h
+++ b/tensorflow/core/platform/cupti_wrapper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_
+#define TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_
#include "tensorflow/core/platform/platform.h"
@@ -24,4 +24,4 @@ limitations under the License.
#include "tensorflow/core/platform/default/gpu/cupti_wrapper.h"
#endif
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CUPTI_WRAPPER_H_
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index f2fadb4558..2cd607edbe 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -122,7 +122,7 @@ cc_library(
"//tensorflow/core:protos_cc",
"@com_googlesource_code_re2//:re2",
"@farmhash_archive//:farmhash",
- "@fft2d//:fft2d",
+ "@fft2d",
"@highwayhash//:sip_hash",
"@png_archive//:png",
],
@@ -140,7 +140,7 @@ cc_library(
name = "jpeg",
copts = tf_copts(),
deps = [
- "@jpeg//:jpeg",
+ "@jpeg",
],
)
diff --git a/tensorflow/core/platform/default/gpu/cupti_wrapper.h b/tensorflow/core/platform/default/gpu/cupti_wrapper.h
index 38e01cefad..acd889e474 100644
--- a/tensorflow/core/platform/default/gpu/cupti_wrapper.h
+++ b/tensorflow/core/platform/default/gpu/cupti_wrapper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_
#if GOOGLE_CUDA
@@ -76,4 +76,4 @@ class CuptiWrapper {
#endif // GOOGLE_CUDA
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CUPTI_WRAPPER_H_
diff --git a/tensorflow/core/platform/demangle.h b/tensorflow/core/platform/demangle.h
index c2def217a1..ce33be2e68 100644
--- a/tensorflow/core/platform/demangle.h
+++ b/tensorflow/core/platform/demangle.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_
+#define TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_
#include "tensorflow/core/platform/types.h"
@@ -28,4 +28,4 @@ string Demangle(const char* mangled);
} // namespace port
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_
diff --git a/tensorflow/core/platform/file_statistics.h b/tensorflow/core/platform/file_statistics.h
index 7629db6ef9..9e3489b1ad 100644
--- a/tensorflow/core/platform/file_statistics.h
+++ b/tensorflow/core/platform/file_statistics.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_
+#define TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_
#include "tensorflow/core/platform/types.h"
@@ -36,4 +36,4 @@ struct FileStatistics {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.h b/tensorflow/core/platform/hadoop/hadoop_file_system.h
index 447e83158a..5f2b222622 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.h
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_
+#define TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_
#include "tensorflow/core/platform/env.h"
@@ -70,4 +70,4 @@ class HadoopFileSystem : public FileSystem {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_
+#endif // TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_
diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.h b/tensorflow/core/platform/profile_utils/cpu_utils.h
index 5d215b4804..e95843b80a 100644
--- a/tensorflow/core/platform/profile_utils/cpu_utils.h
+++ b/tensorflow/core/platform/profile_utils/cpu_utils.h
@@ -42,7 +42,7 @@ namespace profile_utils {
class CpuUtils {
public:
// Constant for invalid frequency.
- // This value is returned when the furequency is not obtained somehow.
+ // This value is returned when the frequency is not obtained somehow.
static constexpr int64 INVALID_FREQUENCY = -1;
static constexpr uint64 DUMMY_CYCLE_CLOCK = 1;
@@ -103,7 +103,7 @@ class CpuUtils {
static int64 GetCycleCounterFrequency();
#endif
- // Return micro secound per each clock
+ // Return micro second per each clock
// As this method caches the cpu frequency internally,
// the first call will incur overhead, but not subsequent calls.
static double GetMicroSecPerClock();
diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD
index 2cd5f877c9..3a0ad2e9bd 100644
--- a/tensorflow/core/platform/s3/BUILD
+++ b/tensorflow/core/platform/s3/BUILD
@@ -45,8 +45,8 @@ tf_cc_binary(
linkshared = 1,
deps = [
"//tensorflow/core:framework_headers_lib",
- "@aws//:aws",
- "@curl//:curl",
+ "@aws",
+ "@curl",
"@protobuf_archive//:protobuf_headers",
],
)
@@ -62,7 +62,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "@aws//:aws",
+ "@aws",
"@boringssl//:crypto",
],
alwayslink = 1,
@@ -79,7 +79,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "@aws//:aws",
+ "@aws",
],
alwayslink = 1,
)
@@ -97,7 +97,7 @@ cc_library(
":s3_crypto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "@aws//:aws",
+ "@aws",
],
alwayslink = 1,
)
@@ -117,6 +117,6 @@ tf_cc_test(
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "@aws//:aws",
+ "@aws",
],
)
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 58ea315670..1e89fa77c1 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -14,14 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/s3/s3_file_system.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/s3/aws_logging.h"
#include "tensorflow/core/platform/s3/s3_crypto.h"
#include <aws/core/Aws.h>
+#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/utils/FileSystemUtils.h>
#include <aws/core/utils/logging/AWSLogging.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
+#include <aws/core/utils/StringUtils.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/S3Errors.h>
#include <aws/s3/model/CopyObjectRequest.h>
@@ -54,13 +57,37 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
cfg.endpointOverride = Aws::String(endpoint);
}
const char* region = getenv("AWS_REGION");
+ if (!region) {
+ // TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
+ region = getenv("S3_REGION");
+ }
if (region) {
cfg.region = Aws::String(region);
} else {
- // TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
- const char* region = getenv("S3_REGION");
- if (region) {
- cfg.region = Aws::String(region);
+ // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
+ // is set with a truthy value.
+ const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
+ string load_config =
+ load_config_env ? str_util::Lowercase(load_config_env) : "";
+ if (load_config == "true" || load_config == "1") {
+ Aws::String config_file;
+ // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
+ const char* config_file_env = getenv("AWS_CONFIG_FILE");
+ if (config_file_env) {
+ config_file = config_file_env;
+ } else {
+ const char* home_env = getenv("HOME");
+ if (home_env) {
+ config_file = home_env;
+ config_file += "/.aws/config";
+ }
+ }
+ Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
+ loader.Load();
+ auto profiles = loader.GetProfiles();
+ if (!profiles["default"].GetRegion().empty()) {
+ cfg.region = profiles["default"].GetRegion();
+ }
}
}
const char* use_https = getenv("S3_USE_HTTPS");
@@ -102,6 +129,16 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
return cfg;
};
+
+void ShutdownClient(Aws::S3::S3Client *s3_client) {
+ if (s3_client != nullptr) {
+ delete s3_client;
+ Aws::SDKOptions options;
+ Aws::ShutdownAPI(options);
+ AWSLogSystem::ShutdownAWSLogging();
+ }
+}
+
Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
string* object) {
if (!bucket || !object) {
@@ -129,12 +166,12 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
class S3RandomAccessFile : public RandomAccessFile {
public:
- S3RandomAccessFile(const string& bucket, const string& object)
- : bucket_(bucket), object_(object) {}
+ S3RandomAccessFile(const string& bucket, const string& object,
+ std::shared_ptr<Aws::S3::S3Client> s3_client)
+ : bucket_(bucket), object_(object), s3_client_(s3_client) {}
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
- Aws::S3::S3Client s3Client(GetDefaultClientConfig());
Aws::S3::Model::GetObjectRequest getObjectRequest;
getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1);
@@ -142,7 +179,7 @@ class S3RandomAccessFile : public RandomAccessFile {
getObjectRequest.SetResponseStreamFactory([]() {
return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag);
});
- auto getObjectOutcome = s3Client.GetObject(getObjectRequest);
+ auto getObjectOutcome = this->s3_client_->GetObject(getObjectRequest);
if (!getObjectOutcome.IsSuccess()) {
n = 0;
*result = StringPiece(scratch, n);
@@ -160,13 +197,16 @@ class S3RandomAccessFile : public RandomAccessFile {
private:
string bucket_;
string object_;
+ std::shared_ptr<Aws::S3::S3Client> s3_client_;
};
class S3WritableFile : public WritableFile {
public:
- S3WritableFile(const string& bucket, const string& object)
+ S3WritableFile(const string& bucket, const string& object,
+ std::shared_ptr<Aws::S3::S3Client> s3_client)
: bucket_(bucket),
object_(object),
+ s3_client_(s3_client),
sync_needed_(true),
outfile_(Aws::MakeShared<Aws::Utils::TempFile>(
kS3FileSystemAllocationTag, "/tmp/s3_filesystem_XXXXXX",
@@ -205,17 +245,13 @@ class S3WritableFile : public WritableFile {
if (!sync_needed_) {
return Status::OK();
}
- Aws::Client::ClientConfiguration clientConfig = GetDefaultClientConfig();
- clientConfig.connectTimeoutMs = 300000;
- clientConfig.requestTimeoutMs = 600000;
- Aws::S3::S3Client s3Client(clientConfig);
Aws::S3::Model::PutObjectRequest putObjectRequest;
putObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
long offset = outfile_->tellp();
outfile_->seekg(0);
putObjectRequest.SetBody(outfile_);
putObjectRequest.SetContentLength(offset);
- auto putObjectOutcome = s3Client.PutObject(putObjectRequest);
+ auto putObjectOutcome = this->s3_client_->PutObject(putObjectRequest);
outfile_->clear();
outfile_->seekp(offset);
if (!putObjectOutcome.IsSuccess()) {
@@ -230,6 +266,7 @@ class S3WritableFile : public WritableFile {
private:
string bucket_;
string object_;
+ std::shared_ptr<Aws::S3::S3Client> s3_client_;
bool sync_needed_;
std::shared_ptr<Aws::Utils::TempFile> outfile_;
};
@@ -248,31 +285,39 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
} // namespace
-S3FileSystem::S3FileSystem() {
- AWSLogSystem::InitializeAWSLogging();
-
- Aws::SDKOptions options;
- options.cryptoOptions.sha256Factory_create_fn = []() {
- return Aws::MakeShared<S3SHA256Factory>(S3CryptoAllocationTag);
- };
- options.cryptoOptions.sha256HMACFactory_create_fn = []() {
- return Aws::MakeShared<S3SHA256HmacFactory>(S3CryptoAllocationTag);
- };
- Aws::InitAPI(options);
-}
+S3FileSystem::S3FileSystem() :
+ s3_client_(nullptr, ShutdownClient), client_lock_() {}
+
+S3FileSystem::~S3FileSystem() {}
+
+// Initializes s3_client_, if needed, and returns it.
+std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
+ std::lock_guard<mutex> lock(this->client_lock_);
+
+ if (this->s3_client_.get() == nullptr) {
+ AWSLogSystem::InitializeAWSLogging();
-S3FileSystem::~S3FileSystem() {
- Aws::SDKOptions options;
- Aws::ShutdownAPI(options);
+ Aws::SDKOptions options;
+ options.cryptoOptions.sha256Factory_create_fn = []() {
+ return Aws::MakeShared<S3SHA256Factory>(S3CryptoAllocationTag);
+ };
+ options.cryptoOptions.sha256HMACFactory_create_fn = []() {
+ return Aws::MakeShared<S3SHA256HmacFactory>(S3CryptoAllocationTag);
+ };
+ Aws::InitAPI(options);
- AWSLogSystem::ShutdownAWSLogging();
+ this->s3_client_ = std::shared_ptr<Aws::S3::S3Client>(
+ new Aws::S3::S3Client(GetDefaultClientConfig()));
+ }
+
+ return this->s3_client_;
}
Status S3FileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
- result->reset(new S3RandomAccessFile(bucket, object));
+ result->reset(new S3RandomAccessFile(bucket, object, this->GetS3Client()));
return Status::OK();
}
@@ -280,7 +325,7 @@ Status S3FileSystem::NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
- result->reset(new S3WritableFile(bucket, object));
+ result->reset(new S3WritableFile(bucket, object, this->GetS3Client()));
return Status::OK();
}
@@ -295,7 +340,7 @@ Status S3FileSystem::NewAppendableFile(const string& fname,
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
- result->reset(new S3WritableFile(bucket, object));
+ result->reset(new S3WritableFile(bucket, object, this->GetS3Client()));
while (true) {
status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk,
@@ -346,7 +391,6 @@ Status S3FileSystem::GetChildren(const string& dir,
prefix.push_back('/');
}
- Aws::S3::S3Client s3Client(GetDefaultClientConfig());
Aws::S3::Model::ListObjectsRequest listObjectsRequest;
listObjectsRequest.WithBucket(bucket.c_str())
.WithPrefix(prefix.c_str())
@@ -357,7 +401,7 @@ Status S3FileSystem::GetChildren(const string& dir,
Aws::S3::Model::ListObjectsResult listObjectsResult;
do {
- auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest);
+ auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest);
if (!listObjectsOutcome.IsSuccess()) {
string error = strings::StrCat(
listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ",
@@ -391,11 +435,10 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, true, &bucket, &object));
- Aws::S3::S3Client s3Client(GetDefaultClientConfig());
if (object.empty()) {
Aws::S3::Model::HeadBucketRequest headBucketRequest;
headBucketRequest.WithBucket(bucket.c_str());
- auto headBucketOutcome = s3Client.HeadBucket(headBucketRequest);
+ auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest);
if (!headBucketOutcome.IsSuccess()) {
string error = strings::StrCat(
headBucketOutcome.GetError().GetExceptionName().c_str(), ": ",
@@ -413,7 +456,7 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
headObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str());
headObjectRequest.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
- auto headObjectOutcome = s3Client.HeadObject(headObjectRequest);
+ auto headObjectOutcome = this->GetS3Client()->HeadObject(headObjectRequest);
if (headObjectOutcome.IsSuccess()) {
stats->length = headObjectOutcome.GetResult().GetContentLength();
stats->is_directory = 0;
@@ -431,7 +474,7 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
.WithMaxKeys(1);
listObjectsRequest.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
- auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest);
+ auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest);
if (listObjectsOutcome.IsSuccess()) {
if (listObjectsOutcome.GetResult().GetContents().size() > 0) {
stats->length = 0;
@@ -449,11 +492,11 @@ Status S3FileSystem::DeleteFile(const string& fname) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
- Aws::S3::S3Client s3Client(GetDefaultClientConfig());
Aws::S3::Model::DeleteObjectRequest deleteObjectRequest;
deleteObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str());
- auto deleteObjectOutcome = s3Client.DeleteObject(deleteObjectRequest);
+ auto deleteObjectOutcome =
+ this->GetS3Client()->DeleteObject(deleteObjectRequest);
if (!deleteObjectOutcome.IsSuccess()) {
string error = strings::StrCat(
deleteObjectOutcome.GetError().GetExceptionName().c_str(), ": ",
@@ -468,10 +511,9 @@ Status S3FileSystem::CreateDir(const string& dirname) {
TF_RETURN_IF_ERROR(ParseS3Path(dirname, true, &bucket, &object));
if (object.empty()) {
- Aws::S3::S3Client s3Client(GetDefaultClientConfig());
Aws::S3::Model::HeadBucketRequest headBucketRequest;
headBucketRequest.WithBucket(bucket.c_str());
- auto headBucketOutcome = s3Client.HeadBucket(headBucketRequest);
+ auto headBucketOutcome = this->GetS3Client()->HeadBucket(headBucketRequest);
if (!headBucketOutcome.IsSuccess()) {
return errors::NotFound("The bucket ", bucket, " was not found.");
}
@@ -491,7 +533,6 @@ Status S3FileSystem::DeleteDir(const string& dirname) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseS3Path(dirname, false, &bucket, &object));
- Aws::S3::S3Client s3Client(GetDefaultClientConfig());
string prefix = object;
if (prefix.back() != '/') {
prefix.push_back('/');
@@ -502,7 +543,7 @@ Status S3FileSystem::DeleteDir(const string& dirname) {
.WithMaxKeys(2);
listObjectsRequest.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
- auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest);
+ auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest);
if (listObjectsOutcome.IsSuccess()) {
auto contents = listObjectsOutcome.GetResult().GetContents();
if (contents.size() > 1 ||
@@ -542,8 +583,6 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) {
}
}
- Aws::S3::S3Client s3Client(GetDefaultClientConfig());
-
Aws::S3::Model::CopyObjectRequest copyObjectRequest;
Aws::S3::Model::DeleteObjectRequest deleteObjectRequest;
@@ -556,7 +595,7 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) {
Aws::S3::Model::ListObjectsResult listObjectsResult;
do {
- auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest);
+ auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest);
if (!listObjectsOutcome.IsSuccess()) {
string error = strings::StrCat(
listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ",
@@ -569,13 +608,14 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) {
Aws::String src_key = object.GetKey();
Aws::String target_key = src_key;
target_key.replace(0, src_object.length(), target_object.c_str());
- Aws::String source = Aws::String(src_bucket.c_str()) + "/" + src_key;
+ Aws::String source = Aws::String(src_bucket.c_str()) + "/"
+ + Aws::Utils::StringUtils::URLEncode(src_key.c_str());
copyObjectRequest.SetBucket(target_bucket.c_str());
copyObjectRequest.SetKey(target_key);
copyObjectRequest.SetCopySource(source);
- auto copyObjectOutcome = s3Client.CopyObject(copyObjectRequest);
+ auto copyObjectOutcome = this->GetS3Client()->CopyObject(copyObjectRequest);
if (!copyObjectOutcome.IsSuccess()) {
string error = strings::StrCat(
copyObjectOutcome.GetError().GetExceptionName().c_str(), ": ",
@@ -586,7 +626,8 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) {
deleteObjectRequest.SetBucket(src_bucket.c_str());
deleteObjectRequest.SetKey(src_key.c_str());
- auto deleteObjectOutcome = s3Client.DeleteObject(deleteObjectRequest);
+ auto deleteObjectOutcome =
+ this->GetS3Client()->DeleteObject(deleteObjectRequest);
if (!deleteObjectOutcome.IsSuccess()) {
string error = strings::StrCat(
deleteObjectOutcome.GetError().GetExceptionName().c_str(), ": ",
diff --git a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h
index 31ba3cecc5..168b8007f3 100644
--- a/tensorflow/core/platform/s3/s3_file_system.h
+++ b/tensorflow/core/platform/s3/s3_file_system.h
@@ -16,7 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_
#define TENSORFLOW_CONTRIB_S3_S3_FILE_SYSTEM_H_
+#include <aws/s3/S3Client.h>
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
@@ -53,6 +55,13 @@ class S3FileSystem : public FileSystem {
Status GetFileSize(const string& fname, uint64* size) override;
Status RenameFile(const string& src, const string& target) override;
+ private:
+ // Returns the member S3 client, initializing as-needed.
+ std::shared_ptr<Aws::S3::S3Client> GetS3Client();
+
+ std::shared_ptr<Aws::S3::S3Client> s3_client_;
+ // Lock held when checking for s3_client_ initialization.
+ mutex client_lock_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/platform/s3/s3_file_system_test.cc b/tensorflow/core/platform/s3/s3_file_system_test.cc
index 0b42f5fcec..d4411d9865 100644
--- a/tensorflow/core/platform/s3/s3_file_system_test.cc
+++ b/tensorflow/core/platform/s3/s3_file_system_test.cc
@@ -130,6 +130,8 @@ TEST_F(S3FileSystemTest, NewReadOnlyMemoryRegionFromFile) {
TEST_F(S3FileSystemTest, FileExists) {
const string fname = TmpDir("FileExists");
+ // Ensure the file doesn't yet exist.
+ TF_ASSERT_OK(s3fs.DeleteFile(fname));
EXPECT_EQ(error::Code::NOT_FOUND, s3fs.FileExists(fname).code());
TF_ASSERT_OK(WriteString(fname, "test"));
TF_EXPECT_OK(s3fs.FileExists(fname));
diff --git a/tensorflow/core/platform/stacktrace_handler.h b/tensorflow/core/platform/stacktrace_handler.h
index d36c82c9ba..a52970fdaa 100644
--- a/tensorflow/core/platform/stacktrace_handler.h
+++ b/tensorflow/core/platform/stacktrace_handler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
+#define TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
namespace tensorflow {
namespace testing {
@@ -25,4 +25,4 @@ void InstallStacktraceHandler();
} // namespace testing
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
+#endif // TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD
index 5fbfc62e74..35d9993018 100644
--- a/tensorflow/core/profiler/BUILD
+++ b/tensorflow/core/profiler/BUILD
@@ -38,7 +38,7 @@ tf_cc_binary(
"//tensorflow/core/profiler/internal:tfprof_stats",
"//tensorflow/core/profiler/internal:tfprof_utils",
"//tensorflow/core/profiler/internal/advisor:tfprof_advisor",
- "@linenoise//:linenoise",
+ "@linenoise",
],
)
diff --git a/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h b/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h
index c6544fe0b0..25766668d8 100644
--- a/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h
+++ b/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This checker checks the accelerator's utilization.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_
#include "tensorflow/core/profiler/internal/advisor/checker.h"
@@ -106,4 +106,4 @@ class AcceleratorUtilizationChecker : public Checker {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_
diff --git a/tensorflow/core/profiler/internal/advisor/checker.h b/tensorflow/core/profiler/internal/advisor/checker.h
index 4b5ebcf9e8..5d7da39e6b 100644
--- a/tensorflow/core/profiler/internal/advisor/checker.h
+++ b/tensorflow/core/profiler/internal/advisor/checker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/profiler/internal/tfprof_stats.h"
@@ -49,4 +49,4 @@ class Checker {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_
diff --git a/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h b/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
index 145782c7bd..f5ac5c9c5a 100644
--- a/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
+++ b/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This checker checks the most expensive operations.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_
#include "tensorflow/core/profiler/internal/advisor/checker.h"
@@ -137,4 +137,4 @@ class ExpensiveOperationChecker : public Checker {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OP_CHECKER_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OP_CHECKER_H_
diff --git a/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h b/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h
index ec52741b19..6fc16cf903 100644
--- a/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h
+++ b/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_
#include "tensorflow/core/profiler/internal/tfprof_utils.h"
#include "tensorflow/core/profiler/tfprof_options.pb.h"
@@ -31,4 +31,4 @@ AdviceProto RunInternalCheckers(const AdvisorOptionsProto& options,
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_
diff --git a/tensorflow/core/profiler/internal/advisor/operation_checker.h b/tensorflow/core/profiler/internal/advisor/operation_checker.h
index f0bd72fa40..6c1d5cd670 100644
--- a/tensorflow/core/profiler/internal/advisor/operation_checker.h
+++ b/tensorflow/core/profiler/internal/advisor/operation_checker.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// This checker checks common wrong configurations of operations.
//
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_
#include "tensorflow/core/profiler/internal/advisor/checker.h"
@@ -74,4 +74,4 @@ class OperationChecker : public Checker {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_
diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h b/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
index 42bd6d5438..270662bd4a 100644
--- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
+++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
#include "tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h"
#include "tensorflow/core/profiler/internal/advisor/checker.h"
@@ -78,4 +78,4 @@ class Advisor {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
diff --git a/tensorflow/core/profiler/internal/print_model_analysis.h b/tensorflow/core/profiler/internal/print_model_analysis.h
index 90166aa7d5..29666ab936 100644
--- a/tensorflow/core/profiler/internal/print_model_analysis.h
+++ b/tensorflow/core/profiler/internal/print_model_analysis.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_
#include <string>
@@ -63,4 +63,4 @@ string PrintModelAnalysis(const string* graph, const string* run_meta,
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_code.h b/tensorflow/core/profiler/internal/tfprof_code.h
index bcbdc1b48c..38395f967c 100644
--- a/tensorflow/core/profiler/internal/tfprof_code.h
+++ b/tensorflow/core/profiler/internal/tfprof_code.h
@@ -16,8 +16,8 @@ limitations under the License.
// Build a tree structure based on the TensorFlow model's python code stacks.
// Stats are aggregated from descendants to ancestors.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_
#include <map>
#include <memory>
@@ -94,4 +94,4 @@ class TFCode : public TFMultiShow {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_constants.h b/tensorflow/core/profiler/internal/tfprof_constants.h
index 6a4eaaa890..d4a47931a2 100644
--- a/tensorflow/core/profiler/internal/tfprof_constants.h
+++ b/tensorflow/core/profiler/internal/tfprof_constants.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_
namespace tensorflow {
namespace tfprof {
@@ -34,4 +34,4 @@ static const char* const kCkptVarType = "_checkpoint_variables";
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_graph.h b/tensorflow/core/profiler/internal/tfprof_graph.h
index f7eef9c835..356a459a65 100644
--- a/tensorflow/core/profiler/internal/tfprof_graph.h
+++ b/tensorflow/core/profiler/internal/tfprof_graph.h
@@ -16,8 +16,8 @@ limitations under the License.
// Build a graph structure based on op inputs/outputs. The graph is a directed
// acyclic graph pointing *from outputs to inputs*.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_
#include <deque>
#include <map>
@@ -86,4 +86,4 @@ class TFGraph : public TFShow {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_node.h b/tensorflow/core/profiler/internal/tfprof_node.h
index 255a0987e6..0a97b1cb0f 100644
--- a/tensorflow/core/profiler/internal/tfprof_node.h
+++ b/tensorflow/core/profiler/internal/tfprof_node.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_
#include <map>
#include <set>
@@ -915,4 +915,4 @@ bool IsCanonicalDevice(const string& device);
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_node_show.h b/tensorflow/core/profiler/internal/tfprof_node_show.h
index ca6f9bca5e..517da67d74 100644
--- a/tensorflow/core/profiler/internal/tfprof_node_show.h
+++ b/tensorflow/core/profiler/internal/tfprof_node_show.h
@@ -21,8 +21,8 @@ limitations under the License.
// ScopeNode and GraphNode each maps to one TFGraphNode.
// CodeNode and OpNode each maps to one TFMultiGraphNode.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_
#include <algorithm>
#include <string>
@@ -156,4 +156,4 @@ class OpNode : public ShowMultiNode {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_op.h b/tensorflow/core/profiler/internal/tfprof_op.h
index fcc5e68f47..fe1c3b2ae8 100644
--- a/tensorflow/core/profiler/internal/tfprof_op.h
+++ b/tensorflow/core/profiler/internal/tfprof_op.h
@@ -15,8 +15,8 @@ limitations under the License.
// Build a flat structure of ops.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_
#include <deque>
#include <map>
@@ -76,4 +76,4 @@ class TFOp : public TFMultiShow {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_scope.h b/tensorflow/core/profiler/internal/tfprof_scope.h
index bb847c0866..235dfde803 100644
--- a/tensorflow/core/profiler/internal/tfprof_scope.h
+++ b/tensorflow/core/profiler/internal/tfprof_scope.h
@@ -17,8 +17,8 @@ limitations under the License.
// For example, 'name1/name2' is a child of 'name1'.
// Stats are aggregated from descendants to ancestors.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_
#include <map>
#include <memory>
@@ -74,4 +74,4 @@ class TFScope : public TFShow {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_show.h b/tensorflow/core/profiler/internal/tfprof_show.h
index 2067ea3b73..4d6de06070 100644
--- a/tensorflow/core/profiler/internal/tfprof_show.h
+++ b/tensorflow/core/profiler/internal/tfprof_show.h
@@ -15,8 +15,8 @@ limitations under the License.
// Parent class and utilities for tfprof_graph and tfprof_scope.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
#include <algorithm>
#include <string>
@@ -151,4 +151,4 @@ string FormatAcceleratorExecTime(const T* node, const Options& opts) {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_show_multi.h b/tensorflow/core/profiler/internal/tfprof_show_multi.h
index ac0ada0449..2a2208d8e7 100644
--- a/tensorflow/core/profiler/internal/tfprof_show_multi.h
+++ b/tensorflow/core/profiler/internal/tfprof_show_multi.h
@@ -15,8 +15,8 @@ limitations under the License.
// Parent class and utilities for tfprof_code.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_
#include <algorithm>
#include <string>
@@ -127,4 +127,4 @@ class TFMultiShow {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_stats.h b/tensorflow/core/profiler/internal/tfprof_stats.h
index d78abda588..db148c936c 100644
--- a/tensorflow/core/profiler/internal/tfprof_stats.h
+++ b/tensorflow/core/profiler/internal/tfprof_stats.h
@@ -20,8 +20,8 @@ limitations under the License.
// 3. Accept command and options to selectively aggregate stats for analysis
// and print out the results.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_
#include <map>
#include <memory>
@@ -83,7 +83,7 @@ class TFStats {
const MultiGraphNodeProto& ShowMultiGraphNode(const string& cmd,
const Options& opts) const;
- // A a (partial) graph to existing graph.
+ // Add a (partial) graph to existing graph.
void AddGraph(std::unique_ptr<GraphDef> graph);
// Add a step of run time meta data.
@@ -118,11 +118,11 @@ class TFStats {
MultiGraphNodeProto empty_multi_graph_node_;
std::map<int64, string> id_to_string_;
- // Graph nodes covered by RunMetdata, that is traced with run time stats.
+ // Graph nodes covered by RunMetadata, that is traced with run time stats.
std::set<int64> covered_nodes_;
};
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_tensor.h b/tensorflow/core/profiler/internal/tfprof_tensor.h
index 9f72e081c9..7a08857720 100644
--- a/tensorflow/core/profiler/internal/tfprof_tensor.h
+++ b/tensorflow/core/profiler/internal/tfprof_tensor.h
@@ -19,8 +19,8 @@ limitations under the License.
// is not supported by TensorFlow CheckPointReader library, though it is
// supported in current code.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_
#include <typeinfo>
@@ -173,4 +173,4 @@ class TFProfTensor {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.h b/tensorflow/core/profiler/internal/tfprof_timeline.h
index b8174cdecb..651ad3f0c1 100644
--- a/tensorflow/core/profiler/internal/tfprof_timeline.h
+++ b/tensorflow/core/profiler/internal/tfprof_timeline.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_
#include "include/json/json.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -178,7 +178,6 @@ class Timeline {
int64 step_;
const string outfile_;
int64 next_pid_ = 0;
- int64 allocator_pid_ = -1;
MemoryTracker mem_tracker_;
ChromeTraceFormatter chrome_formatter_;
std::map<string, int64> device_pids_;
@@ -191,4 +190,4 @@ class Timeline {
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_utils.cc b/tensorflow/core/profiler/internal/tfprof_utils.cc
index 2813bb46fa..7712ebd926 100644
--- a/tensorflow/core/profiler/internal/tfprof_utils.cc
+++ b/tensorflow/core/profiler/internal/tfprof_utils.cc
@@ -355,9 +355,6 @@ static const char* const kOpTypes =
static const char* const kScope =
"scope: The nodes in the model graph are organized by their names, which "
"is hierarchical like filesystem.";
-static const char* const kGraph =
- "graph: The nodes in the model graph are organized by their operation "
- "input and output.";
static const char* const kCode =
"code: When python trace is available, the nodes are python lines and "
"their are organized by the python call stack.";
diff --git a/tensorflow/core/profiler/internal/tfprof_utils.h b/tensorflow/core/profiler/internal/tfprof_utils.h
index afca3df7f8..d4f80afce0 100644
--- a/tensorflow/core/profiler/internal/tfprof_utils.h
+++ b/tensorflow/core/profiler/internal/tfprof_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_
#include <string>
#include <vector>
@@ -72,4 +72,4 @@ string QueryDoc(const string& cmd, const Options& opts);
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_
diff --git a/tensorflow/core/profiler/tfprof_options.h b/tensorflow/core/profiler/tfprof_options.h
index 463f5b3c3a..d61deb72ac 100644
--- a/tensorflow/core/profiler/tfprof_options.h
+++ b/tensorflow/core/profiler/tfprof_options.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
#include <set>
#include <string>
@@ -183,4 +183,4 @@ tensorflow::Status ParseOutput(const string& output_opt, string* output_type,
} // namespace tfprof
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 9b51db1362..3e7289bd91 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -292,7 +292,10 @@ message RecvTensorRequest {
// into a RunGraph call on the same WorkerService.
int64 step_id = 1;
- // A key that identifies the tensor to be received.
+ // A key identifying the channel to receive tensors from. A RecvTensor request
+ // retrieves one tensor from the channel, but multiple tensors can be sent and
+ // received over the same channel with multiple RecvTensor requests. See
+ // rendezvous.h for details.
string rendezvous_key = 2;
// If true, use an out-of-band DMA mechanism to transfer the
@@ -307,6 +310,16 @@ message RecvTensorRequest {
// Optional information needed by the RPC subsystem.
google.protobuf.Any transport_options = 6;
+
+ // Unique identifier for this request. Every RecvTensorRequest must have a
+ // unique request_id, and retried RecvTensorRequests must have the same
+ // request_id. If request_id is zero, retry detection is disabled.
+ //
+ // Retried RecvTensorRequests are problematic because a RecvTensor with no
+ // corresponding sender will wait forever, and the tensor may have been
+ // delivered to a previous retry. Workers use request_ids to reject retried
+ // RecvTensor requests instead of waiting forever.
+ int64 request_id = 7;
}
message RecvTensorResponse {
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index adeb080dde..67da7bf452 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -94,10 +94,12 @@ limitations under the License.
// 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating
// whether default-valued attrs have been stripped from the nodes in the
// GraphDef. (7dec2017)
+// 27. Deprecate TensorArray ops v2 in favor of v3 and deprecated io_ops
+// deprecated in favor of V2 ops. (2018/01/23)
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 25
+#define TF_GRAPH_DEF_VERSION 26
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h
index 121c7063c9..928ae8a4e9 100644
--- a/tensorflow/core/util/command_line_flags.h
+++ b/tensorflow/core/util/command_line_flags.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
-#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
+#ifndef TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
+#define TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
#include <functional>
#include <string>
@@ -134,4 +134,4 @@ class Flags {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
+#endif // TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H
diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h
index d30ab3f4da..53087821d7 100644
--- a/tensorflow/core/util/ctc/ctc_beam_entry.h
+++ b/tensorflow/core/util/ctc/ctc_beam_entry.h
@@ -52,26 +52,25 @@ struct BeamProbability {
float label;
};
+template <class CTCBeamState>
+class BeamRoot;
+
template <class CTCBeamState = EmptyBeamState>
struct BeamEntry {
- // Default constructor does not create a vector of children.
- BeamEntry() : parent(nullptr), label(-1) {}
- // Constructor giving parent, label, and number of children does
- // create a vector of children. The object pointed to by p
- // cannot be copied and should not be moved, otherwise parent will
- // become invalid.
- BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {}
+ // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
+ friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
+ BeamEntry<CTCBeamState>* p, int l);
inline bool Active() const { return newp.total != kLogZero; }
// Return the child at the given index, or construct a new one in-place if
// none was found.
BeamEntry& GetChild(int ind) {
auto entry = children.emplace(ind, nullptr);
auto& child_entry = entry.first->second;
- // If this is a new child, populate the uniqe_ptr.
+ // If this is a new child, populate the BeamEntry<CTCBeamState>*.
if (entry.second) {
- child_entry.reset(new BeamEntry(this, ind));
+ child_entry = beam_root->AddEntry(this, ind);
}
- return *(child_entry.get());
+ return *child_entry;
}
std::vector<int> LabelSeq(bool merge_repeated) const {
std::vector<int> labels;
@@ -90,15 +89,45 @@ struct BeamEntry {
BeamEntry<CTCBeamState>* parent;
int label;
- gtl::FlatMap<int, std::unique_ptr<BeamEntry<CTCBeamState>>> children;
+ // All instances of child BeamEntry are owned by *beam_root.
+ gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
BeamProbability oldp;
BeamProbability newp;
CTCBeamState state;
private:
+ // Constructor giving parent, label, and the beam_root.
+ // The object pointed to by p cannot be copied and should not be moved,
+ // otherwise parent will become invalid.
+ // This private constructor is only called through the factory method
+ // BeamRoot<CTCBeamState>::AddEntry().
+ BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
+ : parent(p), label(l), beam_root(beam_root) {}
+ BeamRoot<CTCBeamState>* beam_root;
TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
};
+// This class owns all instances of BeamEntry. This is used to avoid recursive
+// destructor call during destruction.
+template <class CTCBeamState = EmptyBeamState>
+class BeamRoot {
+ public:
+ BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
+ BeamRoot(const BeamRoot&) = delete;
+ BeamRoot& operator=(const BeamRoot&) = delete;
+
+ BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
+ auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
+ beam_entries_.emplace_back(new_entry);
+ return new_entry;
+ }
+ BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
+
+ private:
+ BeamEntry<CTCBeamState>* root_entry_ = nullptr;
+ std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
+};
+
// BeamComparer is the default beam comparer provided in CTCBeamSearch.
template <class CTCBeamState = EmptyBeamState>
class BeamComparer {
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index 372f25a143..709c65fc96 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -16,11 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
+#include <algorithm>
#include <cmath>
+#include <limits>
#include <memory>
+#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -69,6 +73,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
// but we calculate it recursively for speed purposes.
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
+ typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
typedef ctc_beam_search::BeamProbability BeamProbability;
public:
@@ -142,7 +147,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
float label_selection_margin_ = -1; // -1 means unlimited.
gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
- std::unique_ptr<BeamEntry> beam_root_;
+ std::unique_ptr<BeamRoot> beam_root_;
BaseBeamScorer<CTCBeamState>* beam_scorer_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
@@ -367,15 +372,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
// This beam root, and all of its children, will be in memory until
// the next reset.
- beam_root_.reset(new BeamEntry(nullptr, -1));
- beam_root_->newp.total = 0.0; // ln(1)
- beam_root_->newp.blank = 0.0; // ln(1)
+ beam_root_.reset(new BeamRoot(nullptr, -1));
+ beam_root_->RootEntry()->newp.total = 0.0; // ln(1)
+ beam_root_->RootEntry()->newp.blank = 0.0; // ln(1)
// Add the root as the initial leaf.
- leaves_.push(beam_root_.get());
+ leaves_.push(beam_root_->RootEntry());
// Call initialize state on the root object.
- beam_scorer_->InitializeState(&beam_root_->state);
+ beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
}
template <typename CTCBeamState, typename CTCBeamComparer>
diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h
index 5b28aeb70a..b8bab69053 100644
--- a/tensorflow/core/util/ctc/ctc_decoder.h
+++ b/tensorflow/core/util/ctc/ctc_decoder.h
@@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
+#include <memory>
+#include <vector>
+
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/core/util/cuda_device_functions.h b/tensorflow/core/util/cuda_device_functions.h
new file mode 100644
index 0000000000..f787687f66
--- /dev/null
+++ b/tensorflow/core/util/cuda_device_functions.h
@@ -0,0 +1,499 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
+#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
+
+/**
+ * Wrappers and helpers for CUDA device code.
+ *
+ * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
+ * backwards compatibility, see go/volta-porting for details.
+ * Provides atomic operations on types that aren't natively supported.
+ */
+
+#if GOOGLE_CUDA
+
+#include <algorithm>
+#include <complex>
+#include "cuda/include/cuda.h"
+#include "cuda/include/device_functions.h"
+#include "tensorflow/core/platform/types.h"
+
+#if CUDA_VERSION >= 7050
+#include "cuda/include/cuda_fp16.h"
+#endif // CUDA_VERSION >= 7050
+
+namespace tensorflow {
+
+namespace detail {
+
+// Helper for range-based for loop using 'delta' increments.
+// Usage: see CudaGridRange?() functions below.
+template <typename T>
+class CudaGridRange {
+ struct Iterator {
+ __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
+ __device__ T operator*() const { return index_; }
+ __device__ Iterator& operator++() {
+ index_ += delta_;
+ return *this;
+ }
+ __device__ bool operator!=(const Iterator& other) const {
+ bool greater = index_ > other.index_;
+ bool less = index_ < other.index_;
+ // Anything past an end iterator (delta_ == 0) is equal.
+ // In range-based for loops, this optimizes to 'return less'.
+ if (!other.delta_) {
+ return less;
+ }
+ if (!delta_) {
+ return greater;
+ }
+ return less || greater;
+ }
+
+ private:
+ T index_;
+ const T delta_;
+ };
+
+ public:
+ __device__ CudaGridRange(T begin, T delta, T end)
+ : begin_(begin), delta_(delta), end_(end) {}
+
+ __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
+ __device__ Iterator end() const { return Iterator{end_, 0}; }
+
+ private:
+ T begin_;
+ T delta_;
+ T end_;
+};
+
+} // namespace detail
+
+// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
+// of the global thread index. That is, each index i is visited by all threads
+// with the same x-coordinate.
+// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
+ return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x,
+ gridDim.x * blockDim.x, count);
+}
+
+// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
+// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
+ return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y,
+ gridDim.y * blockDim.y, count);
+}
+
+// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
+// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
+ return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z,
+ gridDim.z * blockDim.z, count);
+}
+
+// Mask for all 32 threads in a warp.
+const unsigned kCudaWarpAll = 0xffffffff;
+
+// Returns the warp lane ID of the calling thread
+__device__ inline unsigned CudaLaneId() {
+ unsigned int lane_id;
+ asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
+ return lane_id;
+}
+
+namespace detail {
+// Returns true if mask is a valid parameter for __shfl*sync to return a well
+// defined value, assuming the calling lane will read from src_lane as part of
+// the shuffle operation.
+//
+// Specifically, returns true iff mask has the calling lane bit and the src_lane
+// bit set, and the src_lane calls this function with the same mask value
+// (required for the two threads to wait for each other).
+//
+// On Volta, for some invalid masks, this function hangs or returns false
+// positives, because the implementation shuffles with the same mask that
+// we are validating. Run on Pascal if you suspect that the mask is incorrect.
+__device__ inline bool CudaValidateShuffleSyncMask(unsigned mask,
+ unsigned src_lane) {
+ unsigned src_dst_mask = 1u << CudaLaneId() | 1u << src_lane;
+#if CUDA_VERSION >= 9000
+ unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane);
+#else
+ unsigned src_lane_mask = __shfl(mask, src_lane);
+#endif
+ return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask;
+}
+
+// Returns the actual source lane for shuffle.
+__device__ inline unsigned CudaShuffleGetSrcLane(int src_lane, int width) {
+ int lane_id = CudaLaneId();
+ int lane_base = lane_id & ~width + 1;
+ int lane_offset = src_lane & width - 1;
+ return lane_base + lane_offset;
+}
+
+// Returns the source lane for shuffle up.
+__device__ inline unsigned CudaShuffleUpGetSrcLane(unsigned delta, int width) {
+ unsigned lane_id = CudaLaneId();
+ if ((lane_id & width - 1) < delta) {
+ return lane_id;
+ }
+ return lane_id - delta;
+}
+
+// Returns the source lane for shuffle down.
+__device__ inline unsigned CudaShuffleDownGetSrcLane(unsigned delta,
+ int width) {
+ unsigned lane_id = CudaLaneId();
+ if ((lane_id & width - 1) + delta >= width) {
+ return lane_id;
+ }
+ return lane_id + delta;
+}
+
+// Returns the source lane for shuffle xor.
+__device__ inline unsigned CudaShuffleXorGetSrcLane(int lane_mask, int width) {
+ int lane_id = CudaLaneId();
+ int src_lane = lane_id ^ lane_mask;
+ if (src_lane > (lane_id | width - 1)) {
+ return lane_id;
+ }
+ return src_lane;
+}
+} // namespace detail
+
+// For all *_sync wrappers below, it is illegal to synchronize threads from
+// different program locations, because that is not supported before sm_70.
+// In other words, all threads in 'mask' must call the functions in convergence.
+// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
+//
+// It is also illegal to shuffle with a mask that produces an undefined result
+// for any of the threads. Specifically, all source threads of the shuffle
+// must have their corresponding bit in 'mask' set.
+
+// Wrapper for __syncwarp. No-op for CUDA 8 and earlier.
+__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
+ assert(mask & 1u << CudaLaneId());
+#if CUDA_VERSION >= 9000
+ __syncwarp(mask);
+#endif
+}
+
+// Wrapper for __ballot_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
+ assert(mask & 1u << CudaLaneId());
+#if CUDA_VERSION >= 9000
+ return __ballot_sync(mask, pred);
+#else
+ return __ballot(pred) & mask; // Apply mask to match __ballot_sync's spec.
+#endif
+}
+
+// Wrapper for __any_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+__device__ inline int CudaAnySync(unsigned mask, int pred) {
+ assert(mask & 1u << CudaLaneId());
+#if CUDA_VERSION >= 9000
+ return __any_sync(mask, pred);
+#else
+ return __any(pred);
+#endif
+}
+
+// Wrapper for __all_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+__device__ inline int CudaAllSync(unsigned mask, int pred) {
+ assert(mask & 1u << CudaLaneId());
+#if CUDA_VERSION >= 9000
+ return __all_sync(mask, pred);
+#else
+ return __all(pred);
+#endif
+}
+
+// Wrapper for __shfl_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+template <typename T>
+__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
+ int width = warpSize) {
+ assert(!(width & width - 1));
+ assert(detail::CudaValidateShuffleSyncMask(
+ mask, detail::CudaShuffleGetSrcLane(src_lane, width)));
+#if CUDA_VERSION >= 9000
+ return __shfl_sync(mask, value, src_lane, width);
+#else
+ return __shfl(value, src_lane, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleSync(unsigned mask, double value,
+ int src_lane, int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleSync(mask, hi, src_lane, width);
+ lo = CudaShuffleSync(mask, lo, src_lane, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+template <typename T>
+__device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta,
+ int width = warpSize) {
+ assert(!(width & width - 1));
+ assert(detail::CudaValidateShuffleSyncMask(
+ mask, detail::CudaShuffleUpGetSrcLane(delta, width)));
+#if CUDA_VERSION >= 9000
+ return __shfl_up_sync(mask, value, delta, width);
+#else
+ return __shfl_up(value, delta, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
+ unsigned delta,
+ int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleUpSync(mask, hi, delta, width);
+ lo = CudaShuffleUpSync(mask, lo, delta, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function
+// in convergence, see comment above for details.
+template <typename T>
+__device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta,
+ int width = warpSize) {
+ assert(!(width & width - 1));
+ assert(detail::CudaValidateShuffleSyncMask(
+ mask, detail::CudaShuffleDownGetSrcLane(delta, width)));
+#if CUDA_VERSION >= 9000
+ return __shfl_down_sync(mask, value, delta, width);
+#else
+ return __shfl_down(value, delta, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
+ unsigned delta,
+ int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleDownSync(mask, hi, delta, width);
+ lo = CudaShuffleDownSync(mask, lo, delta, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+template <typename T>
+__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
+ int width = warpSize) {
+ assert(!(width & width - 1));
+ assert(detail::CudaValidateShuffleSyncMask(
+ mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width)));
+#if CUDA_VERSION >= 9000
+ return __shfl_xor_sync(mask, value, lane_mask, width);
+#else
+ return __shfl_xor(value, lane_mask, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
+ int lane_mask,
+ int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleXorSync(mask, hi, lane_mask, width);
+ lo = CudaShuffleXorSync(mask, lo, lane_mask, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __ldg.
+template <typename T>
+__host__ __device__ T CudaLdg(const T* address) {
+#if __CUDA_ARCH__ >= 350
+ return __ldg(address);
+#else
+ return *address;
+#endif
+}
+
+__host__ __device__ inline bool CudaLdg(const bool* address) {
+ return CudaLdg(reinterpret_cast<const char*>(address)) != 0;
+}
+
+__host__ __device__ inline std::complex<float> CudaLdg(
+ const std::complex<float>* address) {
+#if __CUDA_ARCH__ >= 350
+ float2 mem = __ldg(reinterpret_cast<const float2*>(address));
+ return std::complex<float>(mem.x, mem.y);
+#else
+ return *address;
+#endif
+}
+
+__host__ __device__ inline std::complex<double> CudaLdg(
+ const std::complex<double>* address) {
+#if __CUDA_ARCH__ >= 350
+ double2 mem = __ldg(reinterpret_cast<const double2*>(address));
+ return std::complex<double>(mem.x, mem.y);
+#else
+ return *address;
+#endif
+}
+
+// Zeroes count elements starting at ptr using all threads of a 1-D grid.
+// Note: this function does not synchronize, and therefore the memory range is
+// not guaranteed to be zero until the next kernel launch.
+template <typename T>
+__global__ void SetZero(const int count, T* ptr) {
+ // Check that the grid is one dimensional and index doesn't overflow.
+ assert(blockDim.y == 1 && blockDim.z == 1);
+ assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
+ for (int i : CudaGridRangeX(count)) {
+ ptr[i] = T(0);
+ }
+}
+
+namespace detail {
+// Helper function for atomic accumulation implemented as CAS.
+template <typename T, typename F>
+__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
+ T old = *ptr;
+ T assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(ptr, assumed, accumulate(assumed));
+ } while (assumed != old);
+ return old;
+}
+
+// Overload for floating point (using integer comparison to handle NaN
+// correctly).
+template <typename F>
+__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
+ return __float_as_int(
+ CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
+ return __float_as_int(accumulate(__int_as_float(a)));
+ }));
+}
+template <typename F>
+__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
+ return __longlong_as_double(CudaAtomicCasHelper(
+ reinterpret_cast<tensorflow::uint64*>(ptr),
+ [accumulate](tensorflow::uint64 a) {
+ return __double_as_longlong(accumulate(__longlong_as_double(a)));
+ }));
+}
+
+template <typename From, typename To>
+using ToTypeIfConvertible =
+ typename std::enable_if<std::is_convertible<From, To>::value, To>::type;
+
+} // namespace detail
+
+// CUDA provides atomic ops, but not for all types. We provide wrappers
+// for some ops and provide implementation for all reasonable types.
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicAdd(T* ptr, U value) {
+ return atomicAdd(ptr, value);
+}
+#if __CUDA_ARCH__ < 600
+__device__ inline double CudaAtomicAdd(double* ptr, double value) {
+ return detail::CudaAtomicCasHelper(ptr,
+ [value](double a) { return a + value; });
+}
+#elif __clang__
+// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX.
+// see https://reviews.llvm.org/D39638
+__device__ inline double CudaAtomicAdd(double* ptr, double value) {
+ double result;
+ asm volatile("atom.add.f64 %0, [%1], %2;"
+ : "=d"(result)
+ : "l"(ptr), "d"(value)
+ : "memory");
+ return result;
+}
+#endif
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicSub(T* ptr, U value) {
+ return atomicSub(ptr, value);
+}
+// Specializations of substraction which add the negative value.
+__device__ inline float CudaAtomicSub(float* ptr, float value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+__device__ inline double CudaAtomicSub(double* ptr, double value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
+ tensorflow::uint64 value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMax(T* ptr, U value) {
+ return atomicMax(ptr, value);
+}
+#if __CUDA_ARCH__ < 320
+__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
+ tensorflow::uint64 value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](tensorflow::uint64 a) { return max(a, value); });
+}
+#endif
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMul(T* ptr, U value) {
+ return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; });
+}
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicDiv(T* ptr, U value) {
+ return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; });
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 3e32ec7973..18a4c008f1 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -18,299 +18,133 @@ limitations under the License.
#if GOOGLE_CUDA
-#include <algorithm>
+#include "tensorflow/core/util/cuda_device_functions.h"
+#include "tensorflow/core/util/cuda_launch_config.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "cuda/include/cuda.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor.h"
-#include "tensorflow/core/platform/types.h"
+// Deprecated, use 'for(int i : CudaGridRangeX(n))' instead.
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i : ::tensorflow::CudaGridRangeX<int>(n))
+// Deprecated, use 'for(int i : CudaGridRange?(n))' instead.
+#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
+ for (int i : ::tensorflow::CudaGridRange##axis<int>(n))
-// Mask for all 32 threads in a warp.
-#define CUDA_WARP_ALL 0xFFFFFFFF
-
-#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
-// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive
-// that operates at the warp-scope. This is required to ensure visibility of
-// reads/writes among threads that can make indepenent progress on Volta.
-// For previous CUDA versions these synchronizations not necessary, and we
-// define an empty function as a convenience for backward compatibility.
-__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {}
-
-// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in
-// favor of synchronizing versions. These ensure that all warp lanes specified
-// in mask execute the intrinsic in convergence. Here we provide legacy mappings
-// to the less-verbose routines provided in previous versions of CUDA.
-#define __ballot_sync(mask, predicate) __ballot(predicate)
-#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width)
-#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
-#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width)
-#define __shfl_xor_sync(mask, val, laneMask, width) \
- __shfl_xor(val, laneMask, width)
-#endif
-
-// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
-// GetCuda3DLaunchConfig:
-//
-// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
-// version uses heuristics without any knowledge of the device kernel, the other
-// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
-// launch parameters that maximize occupancy. Currently, only the maximum
-// occupancy version of GetCuda3DLaunchConfig is available.
-//
-// For large number of work elements, the convention is that each kernel would
-// iterate through its assigned range. The return value of GetCudaLaunchConfig
-// is struct CudaLaunchConfig, which contains all the information needed for the
-// kernel launch, including: virtual number of threads, the number of threads
-// per block and number of threads per block used inside <<< >>> of a kernel
-// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing
-// as CudaLaunchConfig. The only difference is the dimension. The macros
-// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop.
-//
-/* Sample code:
-
-__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
- CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
- do_your_job_here;
- }
+namespace tensorflow {
+__host__ __device__ inline tensorflow::bfloat16 CudaLdg(
+ const tensorflow::bfloat16* address) {
+ tensorflow::bfloat16 return_value;
+ return_value.value = CudaLdg(reinterpret_cast<const uint16_t*>(address));
+ return return_value;
}
-__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
- do_your_job_here;
- }
- }
+template <typename T>
+__host__ __device__ inline T ldg(const T* ptr) {
+ return CudaLdg(ptr);
}
-__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
- do_your_job_here;
- }
- }
- }
+template <typename T>
+__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
+ return x < y ? x : y;
}
-void MyDriverFunc(const GPUDevice &d) {
- // use heuristics
- CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
- MyKernel1D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
- Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
- MyKernel2D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
- Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
- MyKernel3D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
-
- // maximize occupancy
- CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
- MyKernel1D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
- Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
- MyKernel1D, 0, 0);
- MyKernel2D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
- Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
- MyKernel1D, 0, 0);
- MyKernel3D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
+template <typename T>
+__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
+ return x < y ? y : x;
}
-// See the test for this for more example:
-//
-https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
-
-*/
-
-#define CUDA_1D_KERNEL_LOOP(i, n) \
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
- i += blockDim.x * gridDim.x)
-
-#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
- for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \
- i += blockDim.axis * gridDim.axis)
-
-#define DIV_UP(a, b) (((a) + (b)-1) / (b))
-
-namespace tensorflow {
-
-typedef Eigen::GpuDevice GPUDevice;
-
-struct CudaLaunchConfig {
- // Logical number of thread that works on the elements. If each logical
- // thread works on exactly a single element, this is the same as the working
- // element count.
- int virtual_thread_count = -1;
- // Number of threads per block.
- int thread_per_block = -1;
- // Number of blocks for Cuda kernel launch.
- int block_count = -1;
-};
-
-// Calculate the Cuda launch config we should use for a kernel launch.
-// This is assuming the kernel is quite simple and will largely be
-// memory-limited.
-// REQUIRES: work_element_count > 0.
-inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
- const GPUDevice& d) {
- CHECK_GT(work_element_count, 0);
- CudaLaunchConfig config;
- const int virtual_thread_count = work_element_count;
- const int physical_thread_count = std::min(
- d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
- virtual_thread_count);
- const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
- const int block_count =
- std::min(DIV_UP(physical_thread_count, thread_per_block),
- d.getNumCudaMultiProcessors());
-
- config.virtual_thread_count = virtual_thread_count;
- config.thread_per_block = thread_per_block;
- config.block_count = block_count;
- return config;
+// Overloads of the above functions for float and double.
+__host__ __device__ inline float tf_min(float x, float y) {
+ return fminf(x, y);
}
-
-// Calculate the Cuda launch config we should use for a kernel launch. This
-// variant takes the resource limits of func into account to maximize occupancy.
-// REQUIRES: work_element_count > 0.
-template <typename DeviceFunc>
-inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
- const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size,
- int block_size_limit) {
- CHECK_GT(work_element_count, 0);
- CudaLaunchConfig config;
- int block_count = 0;
- int thread_per_block = 0;
-
- cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
- &block_count, &thread_per_block, func, dynamic_shared_memory_size,
- block_size_limit);
- CHECK_EQ(err, cudaSuccess);
-
- block_count =
- std::min(block_count, DIV_UP(work_element_count, thread_per_block));
-
- config.virtual_thread_count = work_element_count;
- config.thread_per_block = thread_per_block;
- config.block_count = block_count;
- return config;
+__host__ __device__ inline double tf_min(double x, double y) {
+ return fmin(x, y);
+}
+__host__ __device__ inline float tf_max(float x, float y) {
+ return fmaxf(x, y);
+}
+__host__ __device__ inline double tf_max(double x, double y) {
+ return fmax(x, y);
}
-struct Cuda2DLaunchConfig {
- dim3 virtual_thread_count = dim3(0, 0, 0);
- dim3 thread_per_block = dim3(0, 0, 0);
- dim3 block_count = dim3(0, 0, 0);
-};
-
-inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
- const GPUDevice& d) {
- Cuda2DLaunchConfig config;
-
- if (xdim <= 0 || ydim <= 0) {
- return config;
- }
-
- const int kThreadsPerBlock = 256;
- int block_cols = std::min(xdim, kThreadsPerBlock);
- // ok to round down here and just do more loops in the kernel
- int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
-
- const int physical_thread_count =
- d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
-
- const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
-
- config.virtual_thread_count = dim3(xdim, ydim, 1);
- config.thread_per_block = dim3(block_cols, block_rows, 1);
-
- int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks);
+__device__ inline Eigen::half CudaShuffleSync(unsigned mask, Eigen::half value,
+ int src_lane,
+ int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
+}
- config.block_count = dim3(
- grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
- return config;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleUpSync(
+ unsigned mask, Eigen::half value, int delta, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
}
-// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
-// This variant takes the resource limits of func into account to maximize
-// occupancy.
-using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDownSync(
+ unsigned mask, Eigen::half value, int delta, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
+}
-template <typename DeviceFunc>
-inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
- int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size, int block_size_limit) {
- Cuda3DLaunchConfig config;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
+ unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
+}
- if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
- return config;
+namespace detail {
+// Overload of above function for half. Note that we don't have
+// atomicCAS() for anything less than 32 bits, so we need to include the
+// other 16 bits in the operation.
+//
+// This version is going to be very slow
+// under high concurrency, since most threads will be spinning on failing
+// their compare-and-swap tests. (The fact that we get false sharing on the
+// neighboring fp16 makes this even worse.) If you are doing a large reduction,
+// you are much better off with doing the intermediate steps in fp32 and then
+// switching to fp16 as late as you can in the calculations.
+//
+// Note: Assumes little endian.
+template <typename F>
+__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
+#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
+ static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
+#endif
+ namespace half_impl = Eigen::half_impl;
+ intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
+ assert(!(intptr & 0x1)); // should be 2-aligned.
+ if (intptr & 0x2) {
+ // The half is in the second part of the uint32 (upper 16 bits).
+ uint32* address = reinterpret_cast<uint32*>(intptr - 2);
+ uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
+ unsigned short high = static_cast<unsigned short>(arg >> 16);
+ Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high));
+ return (static_cast<uint32>(acc.x) << 16) | (arg & 0xffff);
+ });
+ return half_impl::raw_uint16_to_half(static_cast<uint16>(result >> 16));
+ } else {
+ // The half is in the first part of the uint32 (lower 16 bits).
+ uint32* address = reinterpret_cast<uint32*>(intptr);
+ uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
+ unsigned short low = static_cast<unsigned short>(arg & 0xffff);
+ Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low));
+ return (arg & 0xffff0000) | static_cast<uint32>(acc.x);
+ });
+ return half_impl::raw_uint16_to_half(static_cast<uint16>(result & 0xffff));
}
-
- int dev;
- cudaGetDevice(&dev);
- cudaDeviceProp deviceProp;
- cudaGetDeviceProperties(&deviceProp, dev);
- int xthreadlimit = deviceProp.maxThreadsDim[0];
- int ythreadlimit = deviceProp.maxThreadsDim[1];
- int zthreadlimit = deviceProp.maxThreadsDim[2];
- int xgridlimit = deviceProp.maxGridSize[0];
- int ygridlimit = deviceProp.maxGridSize[1];
- int zgridlimit = deviceProp.maxGridSize[2];
-
- int block_count = 0;
- int thread_per_block = 0;
- cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
- &block_count, &thread_per_block, func, dynamic_shared_memory_size,
- block_size_limit);
- CHECK_EQ(err, cudaSuccess);
-
-#define MIN3(a, b, c) std::min((a), std::min((b), (c)))
- int threadsx = MIN3(xdim, thread_per_block, xthreadlimit);
- int threadsy =
- MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
- int threadsz =
- MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
- zthreadlimit);
-
- int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit);
- int blocksy =
- MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit);
- int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)),
- DIV_UP(zdim, threadsz), zgridlimit);
-#undef MIN3
-
- config.virtual_thread_count = dim3(xdim, ydim, zdim);
- config.thread_per_block = dim3(threadsx, threadsy, threadsz);
- config.block_count = dim3(blocksx, blocksy, blocksz);
- return config;
}
+} // namespace detail
-template <typename DeviceFunc>
-inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
- int xdim, int ydim, const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size, int block_size_limit) {
- return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
- dynamic_shared_memory_size, block_size_limit);
+__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
+ Eigen::half value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](Eigen::half a) { return a + value; });
}
-
-// Returns a raw reference to the current cuda stream. Required by a
-// number of kernel calls (for which StreamInterface* does not work), i.e.
-// CUB and certain cublas primitives.
-inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
- const cudaStream_t* ptr = CHECK_NOTNULL(
- reinterpret_cast<const cudaStream_t*>(context->op_device_context()
- ->stream()
- ->implementation()
- ->CudaStreamMemberHack()));
- return *ptr;
+__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
+ Eigen::half value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](Eigen::half a) { return a - value; });
}
namespace cuda_helper {
-
template <typename IntType>
__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
IntType* orig = first;
@@ -330,495 +164,8 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
return first - orig;
}
-
} // namespace cuda_helper
-
-template <typename T>
-__device__ __host__ inline T ldg(const T* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return __ldg(address);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline std::complex<float> ldg(
- const std::complex<float>* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- float2 mem = __ldg(reinterpret_cast<const float2*>(address));
- return std::complex<float>(mem.x, mem.y);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline std::complex<double> ldg(
- const std::complex<double>* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- double2 mem = __ldg(reinterpret_cast<const double2*>(address));
- return std::complex<double>(mem.x, mem.y);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline Eigen::half ldg(const Eigen::half* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return Eigen::half_impl::raw_uint16_to_half(
- __ldg(reinterpret_cast<const uint16_t*>(address)));
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline tensorflow::bfloat16 ldg(
- const tensorflow::bfloat16* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- tensorflow::bfloat16 return_value;
- asm volatile("ld.global.nc.u16 %0, [%1];"
- : "=h"(return_value.value)
- : "l"(address));
- return return_value;
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline bool ldg(const bool* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return *reinterpret_cast<const bool*>(
- __ldg(reinterpret_cast<const char*>(address)));
-#else
- return *address;
-#endif
-}
-
-// CUDA provides atomic ops, but not for all types. We provide wrappers
-// for some ops and provide implementation for all reasonable types.
-#define CUDA_ATOMIC_WRAPPER(op, T) \
- __device__ __forceinline__ T CudaAtomic##op(T* address, T val)
-
-#define USE_CUDA_ATOMIC(op, T) \
- CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
-
-// For atomicAdd.
-USE_CUDA_ATOMIC(Add, int32);
-USE_CUDA_ATOMIC(Add, uint32);
-USE_CUDA_ATOMIC(Add, uint64);
-USE_CUDA_ATOMIC(Add, float);
-
-// For atomicMax.
-USE_CUDA_ATOMIC(Max, int32);
-USE_CUDA_ATOMIC(Max, uint32);
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
-USE_CUDA_ATOMIC(Max, uint64);
-#else
-// The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >=
-// 350. If not satisfied, we provide a custom implementation using atomicCAS().
-CUDA_ATOMIC_WRAPPER(Max, uint64) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed, max(val, assumed));
- } while (assumed != old);
-
- return old;
-}
-#endif
-
-// Custom implementation of atomicAdd for double.
-// This implementation is copied from CUDA manual.
-CUDA_ATOMIC_WRAPPER(Add, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(val + __longlong_as_double(assumed)));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- return __longlong_as_double(old);
-}
-
-// Custom implementation of atomicAdd for std::complex<float>.
-// This implementation performs to atomic additions on the components.
-CUDA_ATOMIC_WRAPPER(Add, std::complex<float>) {
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ >= 350
- float2* addr_as_float2 = reinterpret_cast<float2*>(address);
- float2* val_as_float2 = reinterpret_cast<float2*>(&val);
- CudaAtomicAdd(&(addr_as_float2->x), val_as_float2->x);
- CudaAtomicAdd(&(addr_as_float2->y), val_as_float2->y);
-#else
- static_assert(sizeof(std::complex<float>) == 2 * sizeof(float),
- "Unable to compile CudaAtomicAdd for complex64 because "
- "sizeof(complex64) != 2*sizeof(float32)");
- float* addr_as_float = reinterpret_cast<float*>(address);
- float* val_as_float = reinterpret_cast<float*>(&val);
- CudaAtomicAdd(addr_as_float, *val_as_float);
- CudaAtomicAdd(addr_as_float + 1, *(val_as_float + 1));
-#endif
-#endif
- return *address;
-}
-
-// Custom implementation of atomicAdd for std::complex<double>.
-// This implementation performs to atomic additions on the components
-// using the double atomic wrapper above.
-CUDA_ATOMIC_WRAPPER(Add, complex128) {
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ >= 350
- double2* addr_as_double2 = reinterpret_cast<double2*>(address);
- double2* val_as_double2 = reinterpret_cast<double2*>(&val);
- CudaAtomicAdd(&(addr_as_double2->x), val_as_double2->x);
- CudaAtomicAdd(&(addr_as_double2->y), val_as_double2->y);
-#else
- static_assert(sizeof(std::complex<double>) == 2 * sizeof(double),
- "Unable to compile CudaAtomicAdd for complex128 because "
- "sizeof(complex128) != 2*sizeof(float64)");
- double* addr_as_double = reinterpret_cast<double*>(address);
- double* val_as_double = reinterpret_cast<double*>(&val);
- CudaAtomicAdd(addr_as_double, *val_as_double);
- CudaAtomicAdd(addr_as_double + 1, *(val_as_double + 1));
-#endif
-#endif
- return *address;
-}
-
-// Helper functions for CudaAtomicAdd(half*, half), below.
-//
-// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2()
-// for a more efficient implementation, assuming that adding -0.0
-// will never harm the neighboring value. In this version, we take special
-// care to guarantee the bits of the untouched value are unchanged.
-inline __device__ uint32 add_to_low_half(uint32 val, float x) {
- Eigen::half low_half;
- low_half.x = static_cast<uint16>(val & 0xffffu);
- low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x);
- return (val & 0xffff0000u) | low_half.x;
-}
-
-inline __device__ uint32 add_to_high_half(uint32 val, float x) {
- Eigen::half high_half;
- high_half.x = static_cast<uint16>(val >> 16);
- high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x);
- return (val & 0xffffu) | (high_half.x << 16);
-}
-
-// Custom implementation of atomicAdd for half. Note that we don't have
-// atomicCAS() for anything less than 32 bits, so we need to include the
-// other 16 bits in the operation.
-//
-// Unlike the other atomic adds, this version is going to be very slow
-// under high concurrency, since most threads will be spinning on failing
-// their compare-and-swap tests. (The fact that we get false sharing on the
-// neighboring fp16 makes this even worse.) If you are doing a large reduction,
-// you are much better off with doing the intermediate steps in fp32 and then
-// switching to fp16 as late as you can in the calculations.
-//
-// Note: Assumes little endian.
-CUDA_ATOMIC_WRAPPER(Add, Eigen::half) {
- float val_as_float(val);
- intptr_t address_int = reinterpret_cast<intptr_t>(address);
- if ((address_int & 0x2) == 0) {
- // The half is in the first part of the uint32 (lower 16 bits).
- uint32* address_as_uint32 = reinterpret_cast<uint32*>(address);
- assert(((intptr_t)address_as_uint32 & 0x3) == 0);
- uint32 old = *address_as_uint32, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_uint32, assumed,
- add_to_low_half(assumed, val_as_float));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- Eigen::half ret;
- ret.x = old & 0xffffu;
- return ret;
- } else {
- // The half is in the second part of the uint32 (upper 16 bits).
- uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2);
- assert(((intptr_t)address_as_uint32 & 0x3) == 0);
- uint32 old = *address_as_uint32, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_uint32, assumed,
- add_to_high_half(assumed, val_as_float));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- Eigen::half ret;
- ret.x = old >> 16;
- return ret;
- }
-}
-
-template <typename T>
-__global__ void SetZero(const int nthreads, T* bottom_diff) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); }
-}
-
-// For atomicSub.
-
-// Custom implementation for sub by just negating the value.
-#define WRAPPED_ATOMIC_SUB(T) \
- CUDA_ATOMIC_WRAPPER(Sub, T) { return CudaAtomicAdd(address, -val); }
-
-WRAPPED_ATOMIC_SUB(uint64);
-WRAPPED_ATOMIC_SUB(int32);
-WRAPPED_ATOMIC_SUB(uint32);
-WRAPPED_ATOMIC_SUB(Eigen::half);
-WRAPPED_ATOMIC_SUB(float);
-WRAPPED_ATOMIC_SUB(double);
-
-CUDA_ATOMIC_WRAPPER(Sub, complex64) {
- const std::complex<float> Tneg(-val.real(), -val.imag());
- return CudaAtomicAdd(address, Tneg);
-}
-
-CUDA_ATOMIC_WRAPPER(Sub, complex128) {
- const std::complex<double> Tneg(-val.real(), -val.imag());
- return CudaAtomicAdd(address, Tneg);
-}
-
-#undef WRAPPED_ATOMIC_SUB
-
-// For atomicMul.
-CUDA_ATOMIC_WRAPPER(Mul, int32) {
- int32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, uint32) {
- uint32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, uint64) {
- uint64 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, float) {
- int32* address_as_int = reinterpret_cast<int32*>(address);
- int32 old = *address_as_int, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_int, assumed,
- __float_as_int(val * __int_as_float(assumed)));
- } while (assumed != old);
- return __int_as_float(old);
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(val * __longlong_as_double(assumed)));
- } while (assumed != old);
- return __longlong_as_double(old);
-}
-
-// For atomicDiv.
-CUDA_ATOMIC_WRAPPER(Div, int32) {
- int32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, uint32) {
- uint32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, uint64) {
- uint64 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, float) {
- int32* address_as_int = reinterpret_cast<int32*>(address);
- int32 old = *address_as_int, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_int, assumed,
- __float_as_int(__int_as_float(assumed) / val));
- } while (assumed != old);
- return __int_as_float(old);
-}
-
-CUDA_ATOMIC_WRAPPER(Div, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(__longlong_as_double(assumed) / val));
- } while (assumed != old);
- return __longlong_as_double(old);
-}
-
-#undef USE_CUDA_ATOMIC
-#undef CUDA_ATOMIC_WRAPPER
-
-template <typename T>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_min(const T& x, const T& y) {
- return x > y ? y : x;
-}
-
-template <typename T>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
- return x < y ? y : x;
-}
-
-__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask,
- int predicate) {
- return __ballot_sync(mask, predicate);
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value,
- int srcLane,
- int width = warpSize) {
- return __shfl_sync(mask, value, srcLane, width);
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask, double value,
- int srcLane,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_sync(mask, hi, srcLane, width);
- lo = __shfl_sync(mask, lo, srcLane, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask, T value,
- int delta,
- int width = warpSize) {
- return __shfl_up_sync(mask, value, delta, width);
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask, double value,
- int delta,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_up_sync(mask, hi, delta, width);
- lo = __shfl_up_sync(mask, lo, delta, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask, T value,
- int delta,
- int width = warpSize) {
- return __shfl_down_sync(mask, value, delta, width);
-}
-
-__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDown(
- unsigned mask, Eigen::half value, int delta, int width = warpSize) {
- return Eigen::half(
- __shfl_down_sync(mask, static_cast<uint16>(value), delta, width));
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask,
- double value, int delta,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_down_sync(mask, hi, delta, width);
- lo = __shfl_down_sync(mask, lo, delta, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask, T value,
- int laneMask,
- int width = warpSize) {
- return __shfl_xor_sync(mask, value, laneMask, width);
-}
-
-__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXor(
- unsigned mask, Eigen::half value, int laneMask, int width = warpSize) {
- return Eigen::half(
- __shfl_xor_sync(mask, static_cast<uint16>(value), laneMask, width));
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask,
- double value, int laneMask,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_xor_sync(mask, hi, laneMask, width);
- lo = __shfl_xor_sync(mask, lo, laneMask, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
} // namespace tensorflow
-#undef DIV_UP
-
#endif // GOOGLE_CUDA
-
#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
index 6991554eff..bd4c356ea0 100644
--- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
+++ b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
@@ -52,11 +52,11 @@ __global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) {
}
}
__global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
@@ -66,15 +66,15 @@ __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
}
}
__global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) {
if (z < 0) { // z might overflow when testing extreme case
break;
}
@@ -87,6 +87,44 @@ __global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
}
}
+__global__ void CudaShuffleGetSrcLaneTest(unsigned* failure_count) {
+ unsigned lane_id = CudaLaneId();
+ for (int width = warpSize; width > 1; width /= 2) {
+ auto check_result = [&](const char* op_name, int param, unsigned actual,
+ unsigned expected) {
+ if (actual != expected) {
+ printf("Cuda%sGetSrcLane(%d, %d) for lane %d returned %d, not %d\n",
+ op_name, param, width, lane_id, actual, expected);
+ CudaAtomicAdd(failure_count, 1);
+ }
+ };
+ for (int src_lane = -warpSize; src_lane <= warpSize; ++src_lane) {
+ unsigned actual_lane = detail::CudaShuffleGetSrcLane(src_lane, width);
+ unsigned expect_lane =
+ CudaShuffleSync(kCudaWarpAll, lane_id, src_lane, width);
+ check_result("Shuffle", src_lane, actual_lane, expect_lane);
+ }
+ for (unsigned delta = 0; delta <= warpSize; ++delta) {
+ unsigned actual_lane = detail::CudaShuffleUpGetSrcLane(delta, width);
+ unsigned expect_lane =
+ CudaShuffleUpSync(kCudaWarpAll, lane_id, delta, width);
+ check_result("ShuffleUp", delta, actual_lane, expect_lane);
+ }
+ for (unsigned delta = 0; delta <= warpSize; ++delta) {
+ unsigned actual_lane = detail::CudaShuffleDownGetSrcLane(delta, width);
+ unsigned expect_lane =
+ CudaShuffleDownSync(kCudaWarpAll, lane_id, delta, width);
+ check_result("ShuffleDown", delta, actual_lane, expect_lane);
+ }
+ for (int lane_lane = warpSize; lane_lane > 0; lane_lane /= 2) {
+ unsigned actual_lane = detail::CudaShuffleXorGetSrcLane(lane_lane, width);
+ unsigned expect_lane =
+ CudaShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width);
+ check_result("ShuffleXor", lane_lane, actual_lane, expect_lane);
+ }
+ }
+}
+
} // namespace
class CudaLaunchConfigTest : public ::testing::Test {
@@ -94,7 +132,7 @@ class CudaLaunchConfigTest : public ::testing::Test {
const int bufsize = 1024;
int* outbuf = nullptr;
Eigen::CudaStreamDevice stream;
- GPUDevice d = GPUDevice(&stream);
+ Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
virtual void SetUp() {
cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);
@@ -229,6 +267,16 @@ TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) {
#undef TEST_LAUNCH_PARAMETER
}
+TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) {
+ unsigned* failure_count;
+ ASSERT_EQ(cudaMallocManaged(&failure_count, sizeof(unsigned)), cudaSuccess);
+ *failure_count = 0;
+ CudaShuffleGetSrcLaneTest<<<1, 32>>>(failure_count);
+ ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
+ ASSERT_EQ(*failure_count, 0);
+ cudaFree(failure_count);
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h
new file mode 100644
index 0000000000..3ea33ee6cf
--- /dev/null
+++ b/tensorflow/core/util/cuda_launch_config.h
@@ -0,0 +1,284 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
+#define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
+
+#if GOOGLE_CUDA
+
+#include <algorithm>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "cuda/include/cuda.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/platform/types.h"
+
+// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
+// GetCuda3DLaunchConfig:
+//
+// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
+// version uses heuristics without any knowledge of the device kernel, the other
+// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
+// launch parameters that maximize occupancy. Currently, only the maximum
+// occupancy version of GetCuda3DLaunchConfig is available.
+//
+// For large number of work elements, the convention is that each kernel would
+// iterate through its assigned range. The return value of GetCudaLaunchConfig
+// is struct CudaLaunchConfig, which contains all the information needed for the
+// kernel launch, including: virtual number of threads, the number of threads
+// per block and number of threads per block used inside <<< >>> of a kernel
+// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing
+// as CudaLaunchConfig. The only difference is the dimension. The macros
+// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop.
+//
+/* Sample code:
+
+__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
+ CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
+ do_your_job_here;
+ }
+}
+
+__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ do_your_job_here;
+ }
+ }
+}
+
+__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
+ do_your_job_here;
+ }
+ }
+ }
+}
+
+void MyDriverFunc(const Eigen::GpuDevice &d) {
+ // use heuristics
+ CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
+ MyKernel1D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
+ Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
+ MyKernel2D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
+ Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
+ MyKernel3D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
+
+ // maximize occupancy
+ CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
+ MyKernel1D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
+ Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
+ MyKernel1D, 0, 0);
+ MyKernel2D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
+ Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
+ MyKernel1D, 0, 0);
+ MyKernel3D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
+}
+
+// See the test for this for more example:
+//
+https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
+
+*/
+
+namespace tensorflow {
+
+inline int DivUp(int a, int b) { return (a + b - 1) / b; }
+
+struct CudaLaunchConfig {
+ // Logical number of thread that works on the elements. If each logical
+ // thread works on exactly a single element, this is the same as the working
+ // element count.
+ int virtual_thread_count = -1;
+ // Number of threads per block.
+ int thread_per_block = -1;
+ // Number of blocks for Cuda kernel launch.
+ int block_count = -1;
+};
+
+// Calculate the Cuda launch config we should use for a kernel launch.
+// This is assuming the kernel is quite simple and will largely be
+// memory-limited.
+// REQUIRES: work_element_count > 0.
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+ const Eigen::GpuDevice& d) {
+ CHECK_GT(work_element_count, 0);
+ CudaLaunchConfig config;
+ const int virtual_thread_count = work_element_count;
+ const int physical_thread_count = std::min(
+ d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
+ virtual_thread_count);
+ const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
+ const int block_count =
+ std::min(DivUp(physical_thread_count, thread_per_block),
+ d.getNumCudaMultiProcessors());
+
+ config.virtual_thread_count = virtual_thread_count;
+ config.thread_per_block = thread_per_block;
+ config.block_count = block_count;
+ return config;
+}
+
+// Calculate the Cuda launch config we should use for a kernel launch. This
+// variant takes the resource limits of func into account to maximize occupancy.
+// REQUIRES: work_element_count > 0.
+template <typename DeviceFunc>
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+ const Eigen::GpuDevice& d,
+ DeviceFunc func,
+ size_t dynamic_shared_memory_size,
+ int block_size_limit) {
+ CHECK_GT(work_element_count, 0);
+ CudaLaunchConfig config;
+ int block_count = 0;
+ int thread_per_block = 0;
+
+ cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
+ &block_count, &thread_per_block, func, dynamic_shared_memory_size,
+ block_size_limit);
+ CHECK_EQ(err, cudaSuccess);
+
+ block_count =
+ std::min(block_count, DivUp(work_element_count, thread_per_block));
+
+ config.virtual_thread_count = work_element_count;
+ config.thread_per_block = thread_per_block;
+ config.block_count = block_count;
+ return config;
+}
+
+struct Cuda2DLaunchConfig {
+ dim3 virtual_thread_count = dim3(0, 0, 0);
+ dim3 thread_per_block = dim3(0, 0, 0);
+ dim3 block_count = dim3(0, 0, 0);
+};
+
+inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
+ const Eigen::GpuDevice& d) {
+ Cuda2DLaunchConfig config;
+
+ if (xdim <= 0 || ydim <= 0) {
+ return config;
+ }
+
+ const int kThreadsPerBlock = 256;
+ int block_cols = std::min(xdim, kThreadsPerBlock);
+ // ok to round down here and just do more loops in the kernel
+ int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+
+ const int physical_thread_count =
+ d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
+
+ const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
+
+ config.virtual_thread_count = dim3(xdim, ydim, 1);
+ config.thread_per_block = dim3(block_cols, block_rows, 1);
+
+ int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
+
+ config.block_count = dim3(
+ grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
+ return config;
+}
+
+// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
+// This variant takes the resource limits of func into account to maximize
+// occupancy.
+using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
+
+template <typename DeviceFunc>
+inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
+ int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func,
+ size_t dynamic_shared_memory_size, int block_size_limit) {
+ Cuda3DLaunchConfig config;
+
+ if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
+ return config;
+ }
+
+ int dev;
+ cudaGetDevice(&dev);
+ cudaDeviceProp deviceProp;
+ cudaGetDeviceProperties(&deviceProp, dev);
+ int xthreadlimit = deviceProp.maxThreadsDim[0];
+ int ythreadlimit = deviceProp.maxThreadsDim[1];
+ int zthreadlimit = deviceProp.maxThreadsDim[2];
+ int xgridlimit = deviceProp.maxGridSize[0];
+ int ygridlimit = deviceProp.maxGridSize[1];
+ int zgridlimit = deviceProp.maxGridSize[2];
+
+ int block_count = 0;
+ int thread_per_block = 0;
+ cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
+ &block_count, &thread_per_block, func, dynamic_shared_memory_size,
+ block_size_limit);
+ CHECK_EQ(err, cudaSuccess);
+
+ auto min3 = [](int a, int b, int c) { return std::min(a, std::min(b, c)); };
+
+ int threadsx = min3(xdim, thread_per_block, xthreadlimit);
+ int threadsy =
+ min3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
+ int threadsz =
+ min3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
+ zthreadlimit);
+
+ int blocksx = min3(block_count, DivUp(xdim, threadsx), xgridlimit);
+ int blocksy =
+ min3(DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit);
+ int blocksz = min3(DivUp(block_count, (blocksx * blocksy)),
+ DivUp(zdim, threadsz), zgridlimit);
+
+ config.virtual_thread_count = dim3(xdim, ydim, zdim);
+ config.thread_per_block = dim3(threadsx, threadsy, threadsz);
+ config.block_count = dim3(blocksx, blocksy, blocksz);
+ return config;
+}
+
+template <typename DeviceFunc>
+inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
+ int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func,
+ size_t dynamic_shared_memory_size, int block_size_limit) {
+ return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
+ dynamic_shared_memory_size, block_size_limit);
+}
+
+// Returns a raw reference to the current cuda stream. Required by a
+// number of kernel calls (for which StreamInterface* does not work), i.e.
+// CUB and certain cublas primitives.
+inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
+ const cudaStream_t* ptr = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+ return *ptr;
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index fe59ec77ca..1b08f02267 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
+#ifndef TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
+#define TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
#include <string>
#include <unordered_map>
@@ -94,4 +94,4 @@ bool TestFastParse(const string& serialized, Example* example);
} // namespace example
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
+#endif // TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h
index 8b3c6c5a3f..e511704962 100644
--- a/tensorflow/core/util/example_proto_helper.h
+++ b/tensorflow/core/util/example_proto_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
+#ifndef TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
+#define TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
#include <string>
#include <vector>
@@ -314,4 +314,4 @@ class ParseSingleSequenceExampleAttrs {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
+#endif // TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
diff --git a/tensorflow/core/util/matmul_autotune.h b/tensorflow/core/util/matmul_autotune.h
index 5366623883..5846cae2fc 100644
--- a/tensorflow/core/util/matmul_autotune.h
+++ b/tensorflow/core/util/matmul_autotune.h
@@ -15,8 +15,8 @@ limitations under the License.
// The utility to check matmul autotune related flags.
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
+#ifndef TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
+#define TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
namespace tensorflow {
@@ -25,4 +25,4 @@ bool MatmulDoFP32ComputationFP16Input();
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
+#endif // TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h
index abca98f27b..25ecccd285 100644
--- a/tensorflow/core/util/strided_slice_op.h
+++ b/tensorflow/core/util/strided_slice_op.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
+#ifndef TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
+#define TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -62,4 +62,4 @@ Status ValidateStridedSliceOp(
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
+#endif // TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
diff --git a/tensorflow/docs_src/api_guides/python/contrib.signal.md b/tensorflow/docs_src/api_guides/python/contrib.signal.md
index 85ef3ad134..0f7690f80a 100644
--- a/tensorflow/docs_src/api_guides/python/contrib.signal.md
+++ b/tensorflow/docs_src/api_guides/python/contrib.signal.md
@@ -28,14 +28,14 @@ The `axis` parameter to @{tf.contrib.signal.frame} allows you to frame tensors
with inner structure (e.g. a spectrogram):
```python
-# `magnitude_spectrograms` is a [batch_size, ?, 127] tensor of spectrograms. We
+# `magnitude_spectrograms` is a [batch_size, ?, 129] tensor of spectrograms. We
# would like to produce overlapping fixed-size spectrogram patches; for example,
# for use in a situation where a fixed size input is needed.
magnitude_spectrograms = tf.abs(tf.contrib.signal.stft(
signals, frame_length=256, frame_step=64, fft_length=256))
-# `spectrogram_patches` is a [batch_size, ?, 64, 127] tensor containing a
-# variable number of [64, 127] spectrogram patches per batch item.
+# `spectrogram_patches` is a [batch_size, ?, 64, 129] tensor containing a
+# variable number of [64, 129] spectrogram patches per batch item.
spectrogram_patches = tf.contrib.signal.frame(
magnitude_spectrograms, frame_length=64, frame_step=16, axis=1)
```
diff --git a/tensorflow/docs_src/api_guides/python/python_io.md b/tensorflow/docs_src/api_guides/python/python_io.md
index a5444408fe..06282e49d5 100644
--- a/tensorflow/docs_src/api_guides/python/python_io.md
+++ b/tensorflow/docs_src/api_guides/python/python_io.md
@@ -14,16 +14,16 @@ suitable if fast sharding or other non-sequential access is desired.
## TFRecords Format Details
-A TFRecords file contains a sequence of strings with CRC hashes. Each record
-has the format
+A TFRecords file contains a sequence of strings with CRC32C (32-bit CRC using
+the Castagnoli polynomial) hashes. Each record has the format
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
-and the records are concatenated together to produce the file. The CRC32s
-are [described here](https://en.wikipedia.org/wiki/Cyclic_redundancy_check),
-and the mask of a CRC is
+and the records are concatenated together to produce the file. CRCs are
+[described here](https://en.wikipedia.org/wiki/Cyclic_redundancy_check), and
+the mask of a CRC is
masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index e13ddadab7..555a6837d8 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -115,7 +115,7 @@ Take the following steps to install TensorFlow with Virtualenv:
TensorFlow in the active Virtualenv is as follows:
<pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py3-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common-installation-problems).
@@ -238,7 +238,7 @@ take the following steps:
issue the following command:
<pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py2-none-any.whl</b> </pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.5.0rc1-py3-none-any.whl</b> </pre>
If the preceding command fails, see
[installation problems](#common-installation-problems).
diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md
index fa7a94cc06..9f50be5b31 100644
--- a/tensorflow/docs_src/programmers_guide/saved_model.md
+++ b/tensorflow/docs_src/programmers_guide/saved_model.md
@@ -736,6 +736,7 @@ The `run` command provides the following two ways to pass inputs to the model:
* `--inputs` option enables you to pass numpy ndarray in files.
* `--input_exprs` option enables you to pass Python expressions.
+* `--input_examples` option enables you to pass `tf.train.Example`.
#### `--inputs`
@@ -789,19 +790,31 @@ inputs that match the dtype and shape of the model's `SignatureDef`s.
For example:
```bsh
-`input_key=[[1], [2], [3]]`
+`<input_key>=[[1],[2],[3]]`
```
In addition to Python expressions, you may also pass numpy functions. For
example:
```bsh
-input_key=np.ones((32, 32, 3))
+`<input_key>=np.ones((32,32,3))`
```
(Note that the `numpy` module is already available to you as `np`.)
+#### `--inputs_examples`
+
+To pass `tf.train.Example` as inputs, specify the `--input_examples` option.
+For each input key, it takes a list of dictionary, where each dictionary is an
+instance of `tf.train.Example`. The dictionary keys are the features and the
+values are the value lists for each feature.
+For example:
+
+```bsh
+`<input_key>=[{"age":[22,24],"education":["BS","MS"]}]`
+```
+
#### Save Output
By default, the SavedModel CLI writes output to stdout. If a directory is
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index 46df5973e8..1214647797 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -92,7 +92,7 @@ android_binary(
filegroup(
name = "external_assets",
srcs = [
- "@inception5h//:model_files",
+ "@inception_v1//:model_files",
"@mobile_ssd//:model_files",
"@speech_commands//:model_files",
"@stylize//:model_files",
diff --git a/tensorflow/examples/android/download-models.gradle b/tensorflow/examples/android/download-models.gradle
index 0e2cf65f53..d3b67eab52 100644
--- a/tensorflow/examples/android/download-models.gradle
+++ b/tensorflow/examples/android/download-models.gradle
@@ -9,7 +9,7 @@
*/
// hard coded model files
// LINT.IfChange
-def models = ['inception5h.zip',
+def models = ['inception_v1.zip',
'object_detection/ssd_mobilenet_v1_android_export.zip',
'stylize_v1.zip',
'speech_commands_conv_actions.zip']
diff --git a/tensorflow/examples/android/jni/object_tracking/config.h b/tensorflow/examples/android/jni/object_tracking/config.h
index 86e9fc71b6..47de2d2c15 100644
--- a/tensorflow/examples/android/jni/object_tracking/config.h
+++ b/tensorflow/examples/android/jni/object_tracking/config.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
#include <math.h>
@@ -297,4 +297,4 @@ struct TrackerConfig {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_CONFIG_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/flow_cache.h b/tensorflow/examples/android/jni/object_tracking/flow_cache.h
index 8813ab6d71..b62e334ecd 100644
--- a/tensorflow/examples/android/jni/object_tracking/flow_cache.h
+++ b/tensorflow/examples/android/jni/object_tracking/flow_cache.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
@@ -303,4 +303,4 @@ class FlowCache {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.h b/tensorflow/examples/android/jni/object_tracking/frame_pair.h
index 8f409fe806..6c8ac9be98 100644
--- a/tensorflow/examples/android/jni/object_tracking/frame_pair.h
+++ b/tensorflow/examples/android/jni/object_tracking/frame_pair.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
@@ -100,4 +100,4 @@ class FramePair {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/geom.h b/tensorflow/examples/android/jni/object_tracking/geom.h
index 2819063616..c975e40144 100644
--- a/tensorflow/examples/android/jni/object_tracking/geom.h
+++ b/tensorflow/examples/android/jni/object_tracking/geom.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
#include "tensorflow/examples/android/jni/object_tracking/logging.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"
@@ -316,4 +316,4 @@ inline BoundingSquare GetCenteredSquare(const BoundingBox& original_box) {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/gl_utils.h b/tensorflow/examples/android/jni/object_tracking/gl_utils.h
index bd5c233f4f..a29e677d3c 100755
--- a/tensorflow/examples/android/jni/object_tracking/gl_utils.h
+++ b/tensorflow/examples/android/jni/object_tracking/gl_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
#include <GLES/gl.h>
#include <GLES/glext.h>
@@ -52,4 +52,4 @@ inline static void MapWorldSquareToUnitSquare(const BoundingSquare& square) {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GL_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image-inl.h b/tensorflow/examples/android/jni/object_tracking/image-inl.h
index 9c4c389aa7..61d69908b5 100644
--- a/tensorflow/examples/android/jni/object_tracking/image-inl.h
+++ b/tensorflow/examples/android/jni/object_tracking/image-inl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
#include <stdint.h>
@@ -641,4 +641,4 @@ inline void Image<T>::FromArray(const T* const pixels, const int stride,
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_INL_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image.h b/tensorflow/examples/android/jni/object_tracking/image.h
index b7a2301f5e..a436f0e0a1 100644
--- a/tensorflow/examples/android/jni/object_tracking/image.h
+++ b/tensorflow/examples/android/jni/object_tracking/image.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
#include <stdint.h>
@@ -338,4 +338,4 @@ inline std::ostream& operator<<(std::ostream& stream, const Image<t>& image) {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image_data.h b/tensorflow/examples/android/jni/object_tracking/image_data.h
index 445cdb57a3..c4f91d8cbd 100644
--- a/tensorflow/examples/android/jni/object_tracking/image_data.h
+++ b/tensorflow/examples/android/jni/object_tracking/image_data.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
#include <stdint.h>
#include <memory>
@@ -261,4 +261,4 @@ class ImageData {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/image_utils.h b/tensorflow/examples/android/jni/object_tracking/image_utils.h
index ac9ffd90f8..b4ad7000b3 100644
--- a/tensorflow/examples/android/jni/object_tracking/image_utils.h
+++ b/tensorflow/examples/android/jni/object_tracking/image_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
#include <stdint.h>
@@ -295,4 +295,4 @@ inline void NormalizeImage(Image<float>* const image) {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_UTILS_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/integral_image.h b/tensorflow/examples/android/jni/object_tracking/integral_image.h
index 8e82334abf..caf9b7d2ab 100755
--- a/tensorflow/examples/android/jni/object_tracking/integral_image.h
+++ b/tensorflow/examples/android/jni/object_tracking/integral_image.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
@@ -184,4 +184,4 @@ class IntegralImage : public Image<uint32_t> {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/examples/android/jni/object_tracking/jni_utils.h
index 21fbabb521..b81d9e0c12 100644
--- a/tensorflow/examples/android/jni/object_tracking/jni_utils.h
+++ b/tensorflow/examples/android/jni/object_tracking/jni_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_JNI_UTILS_H_
#include <stdint.h>
diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint.h b/tensorflow/examples/android/jni/object_tracking/keypoint.h
index 719f9aff3f..93405a5b2a 100644
--- a/tensorflow/examples/android/jni/object_tracking/keypoint.h
+++ b/tensorflow/examples/android/jni/object_tracking/keypoint.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
@@ -45,4 +45,4 @@ inline std::ostream& operator<<(std::ostream& stream, const Keypoint keypoint) {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h
index 33d228128d..2e85b835a7 100644
--- a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h
+++ b/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
#include <stdint.h>
#include <vector>
@@ -125,4 +125,4 @@ class KeypointDetector {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/logging.h b/tensorflow/examples/android/jni/object_tracking/logging.h
index dbc89af2f7..852a749399 100644
--- a/tensorflow/examples/android/jni/object_tracking/logging.h
+++ b/tensorflow/examples/android/jni/object_tracking/logging.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
#include <android/log.h>
#include <string.h>
@@ -118,4 +118,4 @@ void LogPrintF(const int severity, const char* format, ...);
#endif
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_LOG_STREAMING_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.h b/tensorflow/examples/android/jni/object_tracking/object_detector.h
index 2525567678..a65c7b0db7 100644
--- a/tensorflow/examples/android/jni/object_tracking/object_detector.h
+++ b/tensorflow/examples/android/jni/object_tracking/object_detector.h
@@ -20,8 +20,8 @@ limitations under the License.
// Defines the ObjectDetector class that is the main interface for detecting
// ObjectModelBases in frames.
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
#include <float.h>
#include <map>
@@ -227,4 +227,4 @@ class ObjectDetector : public ObjectDetectorBase {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_model.h b/tensorflow/examples/android/jni/object_tracking/object_model.h
index be33aea638..5e81c49080 100644
--- a/tensorflow/examples/android/jni/object_tracking/object_model.h
+++ b/tensorflow/examples/android/jni/object_tracking/object_model.h
@@ -19,8 +19,8 @@ limitations under the License.
// Contains ObjectModelBase declaration.
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
#ifdef __RENDER_OPENGL__
#include <GLES/gl.h>
@@ -99,4 +99,4 @@ class ObjectModel : public ObjectModelBase {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_DETECTION_OBJECT_MODEL_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/examples/android/jni/object_tracking/object_tracker.h
index eb281fad37..20c7627fc5 100644
--- a/tensorflow/examples/android/jni/object_tracking/object_tracker.h
+++ b/tensorflow/examples/android/jni/object_tracking/object_tracker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
#include <map>
#include <string>
@@ -267,4 +267,4 @@ inline std::ostream& operator<<(std::ostream& stream,
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.h b/tensorflow/examples/android/jni/object_tracking/optical_flow.h
index 2206375beb..f98ae22bd6 100644
--- a/tensorflow/examples/android/jni/object_tracking/optical_flow.h
+++ b/tensorflow/examples/android/jni/object_tracking/optical_flow.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
@@ -97,4 +97,4 @@ class OpticalFlow {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/sprite.h b/tensorflow/examples/android/jni/object_tracking/sprite.h
index 05a13fea11..b54a68458f 100755
--- a/tensorflow/examples/android/jni/object_tracking/sprite.h
+++ b/tensorflow/examples/android/jni/object_tracking/sprite.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
#include <GLES/gl.h>
#include <GLES/glext.h>
@@ -199,4 +199,4 @@ class Sprite {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_SPRITE_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.h b/tensorflow/examples/android/jni/object_tracking/time_log.h
index 60911da396..0073e11596 100644
--- a/tensorflow/examples/android/jni/object_tracking/time_log.h
+++ b/tensorflow/examples/android/jni/object_tracking/time_log.h
@@ -15,8 +15,8 @@ limitations under the License.
// Utility functions for performance profiling.
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
#include <stdint.h>
@@ -134,4 +134,4 @@ inline static void TimeLog(const char* const str) {
inline static void PrintTimeLog() {}
#endif
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TIME_LOG_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.h b/tensorflow/examples/android/jni/object_tracking/tracked_object.h
index cda14e19d2..d7f1a7019b 100644
--- a/tensorflow/examples/android/jni/object_tracking/tracked_object.h
+++ b/tensorflow/examples/android/jni/object_tracking/tracked_object.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
#ifdef __RENDER_OPENGL__
#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h"
@@ -183,4 +183,4 @@ inline std::ostream& operator<<(std::ostream& stream,
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_
diff --git a/tensorflow/examples/android/jni/object_tracking/utils.h b/tensorflow/examples/android/jni/object_tracking/utils.h
index 51cdfcdcfb..2e98734ec4 100644
--- a/tensorflow/examples/android/jni/object_tracking/utils.h
+++ b/tensorflow/examples/android/jni/object_tracking/utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
+#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
+#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
#include <math.h>
#include <stdint.h>
@@ -378,4 +378,4 @@ inline bool Invert2x2(const T* const a, float* const a_inv) {
} // namespace tf_tracking
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
+#endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_UTILS_H_
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java
index a317273acd..bc0c738e53 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java
@@ -81,8 +81,11 @@ public class LegacyCameraConnectionFragment extends Fragment {
try {
Camera.Parameters parameters = camera.getParameters();
- parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE);
-
+ List<String> focusModes = parameters.getSupportedFocusModes();
+ if (focusModes != null
+ && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) {
+ parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE);
+ }
List<Camera.Size> cameraSizes = parameters.getSupportedPreviewSizes();
Size[] sizes = new Size[cameraSizes.size()];
int i = 0;
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
index fa4c1c0da5..461fb1c517 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Train and Eval the MNIST network.
This version is like fully_connected_feed.py but uses data converted
@@ -65,6 +64,7 @@ def decode(serialized_example):
return image, label
+
def augment(image, label):
# OPTIONAL: Could reshape into a 28x28 image and apply distortions
# here. Since we are not applying any distortions in this
@@ -72,12 +72,14 @@ def augment(image, label):
# into a vector, we don't bother.
return image, label
+
def normalize(image, label):
# Convert from [0, 255] -> [-0.5, 0.5] floats.
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
return image, label
+
def inputs(train, batch_size, num_epochs):
"""Reads input data num_epochs times.
@@ -98,9 +100,10 @@ def inputs(train, batch_size, num_epochs):
over the dataset once. On the other hand there is no special initialization
required.
"""
- if not num_epochs: num_epochs = None
- filename = os.path.join(FLAGS.train_dir,
- TRAIN_FILE if train else VALIDATION_FILE)
+ if not num_epochs:
+ num_epochs = None
+ filename = os.path.join(FLAGS.train_dir, TRAIN_FILE
+ if train else VALIDATION_FILE)
with tf.name_scope('input'):
# TFRecordDataset opens a protobuf and reads entries line by line
@@ -127,13 +130,11 @@ def run_training():
# Tell TensorFlow that the model will be built into the default Graph.
with tf.Graph().as_default():
# Input images and labels.
- image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size,
- num_epochs=FLAGS.num_epochs)
+ image_batch, label_batch = inputs(
+ train=True, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs)
# Build a Graph that computes predictions from the inference model.
- logits = mnist.inference(image_batch,
- FLAGS.hidden1,
- FLAGS.hidden2)
+ logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2)
# Add to the Graph the loss calculation.
loss = mnist.loss(logits, label_batch)
@@ -152,7 +153,7 @@ def run_training():
sess.run(init_op)
try:
step = 0
- while True: #train until OutOfRangeError
+ while True: #train until OutOfRangeError
start_time = time.time()
# Run one step of the model. The return values are
@@ -168,10 +169,12 @@ def run_training():
# Print an overview fairly often.
if step % 100 == 0:
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
- duration))
+ duration))
step += 1
except tf.errors.OutOfRangeError:
- print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
+ step))
+
def main(_):
run_training()
@@ -183,37 +186,27 @@ if __name__ == '__main__':
'--learning_rate',
type=float,
default=0.01,
- help='Initial learning rate.'
- )
+ help='Initial learning rate.')
parser.add_argument(
'--num_epochs',
type=int,
default=2,
- help='Number of epochs to run trainer.'
- )
+ help='Number of epochs to run trainer.')
parser.add_argument(
'--hidden1',
type=int,
default=128,
- help='Number of units in hidden layer 1.'
- )
+ help='Number of units in hidden layer 1.')
parser.add_argument(
'--hidden2',
type=int,
default=32,
- help='Number of units in hidden layer 2.'
- )
- parser.add_argument(
- '--batch_size',
- type=int,
- default=100,
- help='Batch size.'
- )
+ help='Number of units in hidden layer 2.')
+ parser.add_argument('--batch_size', type=int, default=100, help='Batch size.')
parser.add_argument(
'--train_dir',
type=str,
default='/tmp/data',
- help='Directory with the training data.'
- )
+ help='Directory with the training data.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/label_image/label_image.py b/tensorflow/examples/label_image/label_image.py
index d62b73384c..1c1bd57d71 100644
--- a/tensorflow/examples/label_image/label_image.py
+++ b/tensorflow/examples/label_image/label_image.py
@@ -23,6 +23,7 @@ import sys
import numpy as np
import tensorflow as tf
+
def load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
@@ -34,22 +35,26 @@ def load_graph(model_file):
return graph
-def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
- input_mean=0, input_std=255):
+
+def read_tensor_from_image_file(file_name,
+ input_height=299,
+ input_width=299,
+ input_mean=0,
+ input_std=255):
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
if file_name.endswith(".png"):
- image_reader = tf.image.decode_png(file_reader, channels = 3,
- name='png_reader')
+ image_reader = tf.image.decode_png(
+ file_reader, channels=3, name="png_reader")
elif file_name.endswith(".gif"):
- image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
- name='gif_reader'))
+ image_reader = tf.squeeze(
+ tf.image.decode_gif(file_reader, name="gif_reader"))
elif file_name.endswith(".bmp"):
- image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
+ image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader")
else:
- image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
- name='jpeg_reader')
+ image_reader = tf.image.decode_jpeg(
+ file_reader, channels=3, name="jpeg_reader")
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0)
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
@@ -59,6 +64,7 @@ def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
return result
+
def load_labels(label_file):
label = []
proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
@@ -66,6 +72,7 @@ def load_labels(label_file):
label.append(l.rstrip())
return label
+
if __name__ == "__main__":
file_name = "tensorflow/examples/label_image/data/grace_hopper.jpg"
model_file = \
@@ -110,11 +117,12 @@ if __name__ == "__main__":
output_layer = args.output_layer
graph = load_graph(model_file)
- t = read_tensor_from_image_file(file_name,
- input_height=input_height,
- input_width=input_width,
- input_mean=input_mean,
- input_std=input_std)
+ t = read_tensor_from_image_file(
+ file_name,
+ input_height=input_height,
+ input_width=input_width,
+ input_mean=input_mean,
+ input_std=input_std)
input_name = "import/" + input_layer
output_name = "import/" + output_layer
@@ -122,8 +130,9 @@ if __name__ == "__main__":
output_operation = graph.get_operation_by_name(output_name)
with tf.Session(graph=graph) as sess:
- results = sess.run(output_operation.outputs[0],
- {input_operation.outputs[0]: t})
+ results = sess.run(output_operation.outputs[0], {
+ input_operation.outputs[0]: t
+ })
results = np.squeeze(results)
top_k = results.argsort()[-5:][::-1]
diff --git a/tensorflow/examples/speech_commands/accuracy_utils.h b/tensorflow/examples/speech_commands/accuracy_utils.h
index 8d918cb64b..eea048365b 100644
--- a/tensorflow/examples/speech_commands/accuracy_utils.h
+++ b/tensorflow/examples/speech_commands/accuracy_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
+#ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
+#define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
#include <vector>
@@ -57,4 +57,4 @@ void PrintAccuracyStats(const StreamingAccuracyStats& stats);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
+#endif // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
diff --git a/tensorflow/examples/speech_commands/recognize_commands.h b/tensorflow/examples/speech_commands/recognize_commands.h
index 7f8041f9ed..a7cd194bec 100644
--- a/tensorflow/examples/speech_commands/recognize_commands.h
+++ b/tensorflow/examples/speech_commands/recognize_commands.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
+#ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
+#define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
#include <deque>
#include <unordered_set>
@@ -76,4 +76,4 @@ class RecognizeCommands {
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
+#endif // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_RECOGNIZE_COMMANDS_H_
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 87cd95165e..d055d15745 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import collections
import math
import os
+import sys
+import argparse
import random
from tempfile import gettempdir
import zipfile
@@ -30,6 +32,24 @@ from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+from tensorflow.contrib.tensorboard.plugins import projector
+
+# Give a folder path as an argument with '--log_dir' to save
+# TensorBoard summaries. Default is a log folder in current directory.
+current_path = os.path.dirname(os.path.realpath(sys.argv[0]))
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ '--log_dir',
+ type=str,
+ default=os.path.join(current_path, 'log'),
+ help='The log directory for TensorBoard summaries.')
+FLAGS, unparsed = parser.parse_known_args()
+
+# Create the directory for TensorBoard variables if there is not.
+if not os.path.exists(FLAGS.log_dir):
+ os.makedirs(FLAGS.log_dir)
+
# Step 1: Download the data.
url = 'http://mattmahoney.net/dc/'
@@ -61,6 +81,7 @@ def read_data(filename):
data = tf.compat.as_str(f.read(f.namelist()[0])).split()
return data
+
vocabulary = read_data(filename)
print('Data size', len(vocabulary))
@@ -86,20 +107,22 @@ def build_dataset(words, n_words):
reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
return data, count, dictionary, reversed_dictionary
+
# Filling 4 global variables:
# data - list of codes (integers from 0 to vocabulary_size-1).
# This is the original text but words are replaced by their codes
# count - map of words(strings) to count of occurrences
# dictionary - map of words(strings) to their codes(integers)
# reverse_dictionary - maps codes(integers) to words(strings)
-data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
- vocabulary_size)
+data, count, dictionary, reverse_dictionary = build_dataset(
+ vocabulary, vocabulary_size)
del vocabulary # Hint to reduce memory.
print('Most common words (+UNK)', count[:5])
print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])
data_index = 0
+
# Step 3: Function to generate a training batch for the skip-gram model.
def generate_batch(batch_size, num_skips, skip_window):
global data_index
@@ -129,96 +152,136 @@ def generate_batch(batch_size, num_skips, skip_window):
data_index = (data_index + len(data) - span) % len(data)
return batch, labels
+
batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
for i in range(8):
- print(batch[i], reverse_dictionary[batch[i]],
- '->', labels[i, 0], reverse_dictionary[labels[i, 0]])
+ print(batch[i], reverse_dictionary[batch[i]], '->', labels[i, 0],
+ reverse_dictionary[labels[i, 0]])
# Step 4: Build and train a skip-gram model.
batch_size = 128
embedding_size = 128 # Dimension of the embedding vector.
-skip_window = 1 # How many words to consider left and right.
-num_skips = 2 # How many times to reuse an input to generate a label.
-num_sampled = 64 # Number of negative examples to sample.
+skip_window = 1 # How many words to consider left and right.
+num_skips = 2 # How many times to reuse an input to generate a label.
+num_sampled = 64 # Number of negative examples to sample.
# We pick a random validation set to sample nearest neighbors. Here we limit the
# validation samples to the words that have a low numeric ID, which by
# construction are also the most frequent. These 3 variables are used only for
# displaying model accuracy, they don't affect calculation.
-valid_size = 16 # Random set of words to evaluate similarity on.
+valid_size = 16 # Random set of words to evaluate similarity on.
valid_window = 100 # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)
-
graph = tf.Graph()
with graph.as_default():
# Input data.
- train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
- train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
- valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
+ with tf.name_scope('inputs'):
+ train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
+ train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
+ valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
# Ops and variables pinned to the CPU because of missing GPU implementation
with tf.device('/cpu:0'):
# Look up embeddings for inputs.
- embeddings = tf.Variable(
- tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
- embed = tf.nn.embedding_lookup(embeddings, train_inputs)
+ with tf.name_scope('embeddings'):
+ embeddings = tf.Variable(
+ tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
+ embed = tf.nn.embedding_lookup(embeddings, train_inputs)
# Construct the variables for the NCE loss
- nce_weights = tf.Variable(
- tf.truncated_normal([vocabulary_size, embedding_size],
- stddev=1.0 / math.sqrt(embedding_size)))
- nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
+ with tf.name_scope('weights'):
+ nce_weights = tf.Variable(
+ tf.truncated_normal(
+ [vocabulary_size, embedding_size],
+ stddev=1.0 / math.sqrt(embedding_size)))
+ with tf.name_scope('biases'):
+ nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
# Compute the average NCE loss for the batch.
# tf.nce_loss automatically draws a new sample of the negative labels each
# time we evaluate the loss.
# Explanation of the meaning of NCE loss:
# http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/
- loss = tf.reduce_mean(
- tf.nn.nce_loss(weights=nce_weights,
- biases=nce_biases,
- labels=train_labels,
- inputs=embed,
- num_sampled=num_sampled,
- num_classes=vocabulary_size))
+ with tf.name_scope('loss'):
+ loss = tf.reduce_mean(
+ tf.nn.nce_loss(
+ weights=nce_weights,
+ biases=nce_biases,
+ labels=train_labels,
+ inputs=embed,
+ num_sampled=num_sampled,
+ num_classes=vocabulary_size))
+
+ # Add the loss value as a scalar to summary.
+ tf.summary.scalar('loss', loss)
# Construct the SGD optimizer using a learning rate of 1.0.
- optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)
+ with tf.name_scope('optimizer'):
+ optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)
# Compute the cosine similarity between minibatch examples and all embeddings.
norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
normalized_embeddings = embeddings / norm
- valid_embeddings = tf.nn.embedding_lookup(
- normalized_embeddings, valid_dataset)
+ valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings,
+ valid_dataset)
similarity = tf.matmul(
valid_embeddings, normalized_embeddings, transpose_b=True)
+ # Merge all summaries.
+ merged = tf.summary.merge_all()
+
# Add variable initializer.
init = tf.global_variables_initializer()
+ # Create a saver.
+ saver = tf.train.Saver()
+
# Step 5: Begin training.
num_steps = 100001
with tf.Session(graph=graph) as session:
+ # Open a writer to write summaries.
+ writer = tf.summary.FileWriter(FLAGS.log_dir, session.graph)
+
# We must initialize all variables before we use them.
init.run()
print('Initialized')
average_loss = 0
for step in xrange(num_steps):
- batch_inputs, batch_labels = generate_batch(
- batch_size, num_skips, skip_window)
+ batch_inputs, batch_labels = generate_batch(batch_size, num_skips,
+ skip_window)
feed_dict = {train_inputs: batch_inputs, train_labels: batch_labels}
+ # Define metadata variable.
+ run_metadata = tf.RunMetadata()
+
# We perform one update step by evaluating the optimizer op (including it
# in the list of returned values for session.run()
- _, loss_val = session.run([optimizer, loss], feed_dict=feed_dict)
+ # Also, evaluate the merged op to get all summaries from the returned "summary" variable.
+ # Feed metadata variable to session for visualizing the graph in TensorBoard.
+ _, summary, loss_val = session.run(
+ [optimizer, merged, loss],
+ feed_dict=feed_dict,
+ run_metadata=run_metadata)
average_loss += loss_val
+ # Add returned summaries to writer in each step.
+ writer.add_summary(summary, step)
+ # Add metadata to visualize the graph for the last run.
+ if step == (num_steps - 1):
+ writer.add_run_metadata(run_metadata, 'step%d' % step)
+
+ # Add returned summaries to writer in each step.
+ writer.add_summary(summary, step)
+ # Add metadata to visualize the graph for the last run.
+ if step == (num_steps - 1):
+ writer.add_run_metadata(run_metadata, 'step%d' % step)
+
if step % 2000 == 0:
if step > 0:
average_loss /= 2000
@@ -240,6 +303,23 @@ with tf.Session(graph=graph) as session:
print(log_str)
final_embeddings = normalized_embeddings.eval()
+ # Write corresponding labels for the embeddings.
+ with open(FLAGS.log_dir + '/metadata.tsv', 'w') as f:
+ for i in xrange(vocabulary_size):
+ f.write(reverse_dictionary[i] + '\n')
+
+ # Save the model for checkpoints.
+ saver.save(session, os.path.join(FLAGS.log_dir, 'model.ckpt'))
+
+ # Create a configuration for visualizing embeddings with the labels in TensorBoard.
+ config = projector.ProjectorConfig()
+ embedding_conf = config.embeddings.add()
+ embedding_conf.tensor_name = embeddings.name
+ embedding_conf.metadata_path = os.path.join(FLAGS.log_dir, 'metadata.tsv')
+ projector.visualize_embeddings(writer, config)
+
+writer.close()
+
# Step 6: Visualize the embeddings.
@@ -251,21 +331,24 @@ def plot_with_labels(low_dim_embs, labels, filename):
for i, label in enumerate(labels):
x, y = low_dim_embs[i, :]
plt.scatter(x, y)
- plt.annotate(label,
- xy=(x, y),
- xytext=(5, 2),
- textcoords='offset points',
- ha='right',
- va='bottom')
+ plt.annotate(
+ label,
+ xy=(x, y),
+ xytext=(5, 2),
+ textcoords='offset points',
+ ha='right',
+ va='bottom')
plt.savefig(filename)
+
try:
# pylint: disable=g-import-not-at-top
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
- tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact')
+ tsne = TSNE(
+ perplexity=30, n_components=2, init='pca', n_iter=5000, method='exact')
plot_only = 500
low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only, :])
labels = [reverse_dictionary[i] for i in xrange(plot_only)]
diff --git a/tensorflow/examples/udacity/Dockerfile b/tensorflow/examples/udacity/Dockerfile
index 3ca58566c1..00eb853e52 100644
--- a/tensorflow/examples/udacity/Dockerfile
+++ b/tensorflow/examples/udacity/Dockerfile
@@ -8,7 +8,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
-RUN pip install scikit-learn pyreadline Pillow
+RUN pip install scikit-learn pyreadline Pillow imageio
RUN rm -rf /notebooks/*
ADD *.ipynb /notebooks/
WORKDIR /notebooks
diff --git a/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h
index fa8cb0abe9..eada07e06f 100644
--- a/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h
+++ b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
+#ifndef TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
+#define TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
@@ -28,4 +28,4 @@ tensorflow::Status WavToSpectrogram(const tensorflow::string& input_wav,
tensorflow::int32 stride, float brightness,
const tensorflow::string& output_image);
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
+#endif // TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 7bcc55959c..5b19c90238 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -116,6 +116,110 @@ func WriteImageSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Ou
return scope.AddOperation(opspec)
}
+// Outputs a `tf.Event` protocol buffer.
+//
+// When CreateSummaryDbWriter is being used, this op can be useful for
+// importing data from event logs.
+//
+// Arguments:
+// writer: A handle to a summary writer.
+// event: A string containing a binary-encoded tf.Event proto.
+//
+// Returns the created operation.
+func ImportEvent(scope *Scope, writer tf.Output, event tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ImportEvent",
+ Input: []tf.Input{
+ writer, event,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Outputs a `Summary` protocol buffer with a tensor.
+//
+// Arguments:
+// writer: A handle to a summary writer.
+// step: The step to write the summary for.
+// tensor: A tensor to serialize.
+// tag: The summary's tag.
+// summary_metadata: Serialized SummaryMetadata protocol buffer containing
+// plugin-related metadata for this summary.
+//
+// Returns the created operation.
+func WriteSummary(scope *Scope, writer tf.Output, step tf.Output, tensor tf.Output, tag tf.Output, summary_metadata tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "WriteSummary",
+ Input: []tf.Input{
+ writer, step, tensor, tag, summary_metadata,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Creates summary database writer accessible by given resource handle.
+//
+// This can be used to write tensors from the execution graph directly
+// to a database. Only SQLite is supported right now. This function
+// will create the schema if it doesn't exist. Entries in the Users,
+// Experiments, and Runs tables will be created automatically if they
+// don't already exist.
+//
+// Arguments:
+// writer: Handle to SummaryWriter resource to overwrite.
+// db_uri: For example "file:/tmp/foo.sqlite".
+// experiment_name: Can't contain ASCII control characters or <>. Case
+// sensitive. If empty, then the Run will not be associated with any
+// Experiment.
+// run_name: Can't contain ASCII control characters or <>. Case sensitive.
+// If empty, then each Tag will not be associated with any Run.
+// user_name: Must be valid as both a DNS label and Linux username. If
+// empty, then the Experiment will not be associated with any User.
+//
+// Returns the created operation.
+func CreateSummaryDbWriter(scope *Scope, writer tf.Output, db_uri tf.Output, experiment_name tf.Output, run_name tf.Output, user_name tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "CreateSummaryDbWriter",
+ Input: []tf.Input{
+ writer, db_uri, experiment_name, run_name, user_name,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Creates a summary file writer accessible by the given resource handle.
+//
+// Arguments:
+// writer: A handle to the summary writer resource
+// logdir: Directory where the event file will be written.
+// max_queue: Size of the queue of pending events and summaries.
+// flush_millis: How often, in milliseconds, to flush the pending events and
+// summaries to disk.
+// filename_suffix: Every event file's name is suffixed with this suffix.
+//
+// Returns the created operation.
+func CreateSummaryFileWriter(scope *Scope, writer tf.Output, logdir tf.Output, max_queue tf.Output, flush_millis tf.Output, filename_suffix tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "CreateSummaryFileWriter",
+ Input: []tf.Input{
+ writer, logdir, max_queue, flush_millis, filename_suffix,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
// Partitions `data` into `num_partitions` tensors using indices from `partitions`.
//
// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]`
@@ -2357,6 +2461,8 @@ func TensorArrayV2TensorArrayName(value string) TensorArrayV2Attr {
}
// Deprecated. Use TensorArrayV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayV3
func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV2Attr) (handle tf.Output) {
if scope.Err() != nil {
return
@@ -3117,39 +3223,6 @@ func HistogramFixedWidth(scope *Scope, values tf.Output, value_range tf.Output,
return op.Output(0)
}
-// Creates summary database writer accessible by given resource handle.
-//
-// This can be used to write tensors from the execution graph directly
-// to a database. Only SQLite is supported right now. This function
-// will create the schema if it doesn't exist. Entries in the Users,
-// Experiments, and Runs tables will be created automatically if they
-// don't already exist.
-//
-// Arguments:
-// writer: Handle to SummaryWriter resource to overwrite.
-// db_uri: For example "file:/tmp/foo.sqlite".
-// experiment_name: Can't contain ASCII control characters or <>. Case
-// sensitive. If empty, then the Run will not be associated with any
-// Experiment.
-// run_name: Can't contain ASCII control characters or <>. Case sensitive.
-// If empty, then each Tag will not be associated with any Run.
-// user_name: Must be valid as both a DNS label and Linux username. If
-// empty, then the Experiment will not be associated with any User.
-//
-// Returns the created operation.
-func CreateSummaryDbWriter(scope *Scope, writer tf.Output, db_uri tf.Output, experiment_name tf.Output, run_name tf.Output, user_name tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "CreateSummaryDbWriter",
- Input: []tf.Input{
- writer, db_uri, experiment_name, run_name, user_name,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Adds Tensor 'bias' to Tensor 'input' for Quantized types.
//
// Broadcasts the values of bias on dimensions 0..N-2 of 'input'.
@@ -5413,6 +5486,72 @@ func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_f
return op.Output(0), op.Output(1), op.Output(2)
}
+// SummaryWriterAttr is an optional argument to SummaryWriter.
+type SummaryWriterAttr func(optionalAttr)
+
+// SummaryWriterSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func SummaryWriterSharedName(value string) SummaryWriterAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// SummaryWriterContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func SummaryWriterContainer(value string) SummaryWriterAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// Returns a handle to be used to access a summary writer.
+//
+// The summary writer is an in-graph resource which can be used by ops to write
+// summaries to event files.
+//
+// Returns the summary writer resource. Scalar handle.
+func SummaryWriter(scope *Scope, optional ...SummaryWriterAttr) (writer tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SummaryWriter",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes gradients for SparseSegmentMean.
+//
+// Returns tensor "output" with same shape as grad, except for dimension 0 whose
+// value is output_dim0.
+//
+// Arguments:
+// grad: gradient propagated to the SparseSegmentMean op.
+// indices: indices passed to the corresponding SparseSegmentMean op.
+// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
+// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
+func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMeanGrad",
+ Input: []tf.Input{
+ grad, indices, segment_ids, output_dim0,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Applies softmax to a batched N-D `SparseTensor`.
//
// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]`
@@ -7427,30 +7566,6 @@ func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...Va
return op.Output(0)
}
-// Creates a summary file writer accessible by the given resource handle.
-//
-// Arguments:
-// writer: A handle to the summary writer resource
-// logdir: Directory where the event file will be written.
-// max_queue: Size of the queue of pending events and summaries.
-// flush_millis: How often, in milliseconds, to flush the pending events and
-// summaries to disk.
-// filename_suffix: Every event file's name is suffixed with this suffix.
-//
-// Returns the created operation.
-func CreateSummaryFileWriter(scope *Scope, writer tf.Output, logdir tf.Output, max_queue tf.Output, flush_millis tf.Output, filename_suffix tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "CreateSummaryFileWriter",
- Input: []tf.Input{
- writer, logdir, max_queue, flush_millis, filename_suffix,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Elementwise computes the bitwise XOR of `x` and `y`.
//
// The result will have those bits set, that are different in `x` and `y`. The
@@ -10353,6 +10468,8 @@ func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output,
}
// Deprecated. Use TensorArraySplitV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3
func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) {
if scope.Err() != nil {
return
@@ -10908,6 +11025,165 @@ func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_s
return op.Output(0)
}
+// Flushes the writer's unwritten events.
+//
+// Arguments:
+// writer: A handle to the summary writer resource.
+//
+// Returns the created operation.
+func FlushSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "FlushSummaryWriter",
+ Input: []tf.Input{
+ writer,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// QuantizeV2Attr is an optional argument to QuantizeV2.
+type QuantizeV2Attr func(optionalAttr)
+
+// QuantizeV2Mode sets the optional mode attribute to value.
+// If not specified, defaults to "MIN_COMBINED"
+func QuantizeV2Mode(value string) QuantizeV2Attr {
+ return func(m optionalAttr) {
+ m["mode"] = value
+ }
+}
+
+// QuantizeV2RoundMode sets the optional round_mode attribute to value.
+// If not specified, defaults to "HALF_AWAY_FROM_ZERO"
+func QuantizeV2RoundMode(value string) QuantizeV2Attr {
+ return func(m optionalAttr) {
+ m["round_mode"] = value
+ }
+}
+
+// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'.
+//
+// [min_range, max_range] are scalar floats that specify the range for
+// the 'input' data. The 'mode' attribute controls exactly which calculations are
+// used to convert the float values to their quantized equivalents. The
+// 'round_mode' attribute controls which rounding tie-breaking algorithm is used
+// when rounding float values to their quantized equivalents.
+//
+// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following:
+//
+// ```
+// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range)
+// if T == qint8, out[i] -= (range(T) + 1) / 2.0
+// ```
+// here `range(T) = numeric_limits<T>::max() - numeric_limits<T>::min()`
+//
+// *MIN_COMBINED Mode Example*
+//
+// Assume the input is type float and has a possible range of [0.0, 6.0] and the
+// output type is quint8 ([0, 255]). The min_range and max_range values should be
+// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each
+// value of the input by 255/6 and cast to quint8.
+//
+// If the output type was qint8 ([-128, 127]), the operation will additionally
+// subtract each value by 128 prior to casting, so that the range of values aligns
+// with the range of qint8.
+//
+// If the mode is 'MIN_FIRST', then this approach is used:
+//
+// ```
+// num_discrete_values = 1 << (# of bits in T)
+// range_adjust = num_discrete_values / (num_discrete_values - 1)
+// range = (range_max - range_min) * range_adjust
+// range_scale = num_discrete_values / range
+// quantized = round(input * range_scale) - round(range_min * range_scale) +
+// numeric_limits<T>::min()
+// quantized = max(quantized, numeric_limits<T>::min())
+// quantized = min(quantized, numeric_limits<T>::max())
+// ```
+//
+// The biggest difference between this and MIN_COMBINED is that the minimum range
+// is rounded first, before it's subtracted from the rounded value. With
+// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing
+// and dequantizing will introduce a larger and larger error.
+//
+// *SCALED mode Example*
+//
+// `SCALED` mode matches the quantization approach used in
+// `QuantizeAndDequantize{V2|V3}`.
+//
+// If the mode is `SCALED`, we do not use the full range of the output type,
+// choosing to elide the lowest possible value for symmetry (e.g., output range is
+// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to
+// 0.
+//
+// We first find the range of values in our tensor. The
+// range we use is always centered on 0, so we find m such that
+// ```c++
+// m = max(abs(input_min), abs(input_max))
+// ```
+//
+// Our input tensor range is then `[-m, m]`.
+//
+// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`.
+// If T is signed, this is
+// ```
+// num_bits = sizeof(T) * 8
+// [min_fixed, max_fixed] =
+// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1]
+// ```
+//
+// Otherwise, if T is unsigned, the fixed-point range is
+// ```
+// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]
+// ```
+//
+// From this we compute our scaling factor, s:
+// ```c++
+// s = (max_fixed - min_fixed) / (2 * m)
+// ```
+//
+// Now we can quantize the elements of our tensor:
+// ```c++
+// result = round(input * s)
+// ```
+//
+// One thing to watch out for is that the operator may choose to adjust the
+// requested minimum and maximum values slightly during the quantization process,
+// so you should always use the output ports as the range for further calculations.
+// For example, if the requested minimum and maximum values are close to equal,
+// they will be separated by a small epsilon value to prevent ill-formed quantized
+// buffers from being created. Otherwise, you can end up with buffers where all the
+// quantized values map to the same float value, which causes problems for
+// operations that have to perform further calculations on them.
+//
+// Arguments:
+//
+// min_range: The minimum scalar value possibly produced for the input.
+// max_range: The maximum scalar value possibly produced for the input.
+//
+//
+// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output.
+func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"T": T}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "QuantizeV2",
+ Input: []tf.Input{
+ input, min_range, max_range,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
// Component-wise divides a SparseTensor by a dense Tensor.
//
// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not
@@ -11607,6 +11883,8 @@ func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output
}
// Deprecated. Use TensorArrayReadV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3
func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
if scope.Err() != nil {
return
@@ -14420,218 +14698,6 @@ func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...Ran
return op.Output(0)
}
-// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad.
-type AvgPool3DGradAttr func(optionalAttr)
-
-// AvgPool3DGradDataFormat sets the optional data_format attribute to value.
-//
-// value: The data format of the input and output data. With the
-// default format "NDHWC", the data is stored in the order of:
-// [batch, in_depth, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCDHW", the data storage order is:
-// [batch, in_channels, in_depth, in_height, in_width].
-// If not specified, defaults to "NDHWC"
-func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Computes gradients of average pooling function.
-//
-// Arguments:
-// orig_input_shape: The original input dimensions.
-// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
-// ksize: 1-D tensor of length 5. The size of the window for each dimension of
-// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
-// strides: 1-D tensor of length 5. The stride of the sliding window for each
-// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
-// padding: The type of padding algorithm to use.
-//
-// Returns The backprop for input.
-func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AvgPool3DGrad",
- Input: []tf.Input{
- orig_input_shape, grad,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample.
-type ParseSingleSequenceExampleAttr func(optionalAttr)
-
-// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value.
-//
-// value: A list of Ncontext_sparse types; the data types of data in
-// each context Feature given in context_sparse_keys.
-// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
-// DT_INT64 (Int64List), and DT_STRING (BytesList).
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["context_sparse_types"] = value
- }
-}
-
-// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value.
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["feature_list_dense_types"] = value
- }
-}
-
-// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value.
-//
-// value: A list of Ncontext_dense shapes; the shapes of data in
-// each context Feature given in context_dense_keys.
-// The number of elements in the Feature corresponding to context_dense_key[j]
-// must always equal context_dense_shapes[j].NumEntries().
-// The shape of context_dense_values[j] will match context_dense_shapes[j].
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["context_dense_shapes"] = value
- }
-}
-
-// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value.
-//
-// value: A list of Nfeature_list_sparse types; the data types
-// of data in each FeatureList given in feature_list_sparse_keys.
-// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
-// DT_INT64 (Int64List), and DT_STRING (BytesList).
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["feature_list_sparse_types"] = value
- }
-}
-
-// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value.
-//
-// value: A list of Nfeature_list_dense shapes; the shapes of
-// data in each FeatureList given in feature_list_dense_keys.
-// The shape of each Feature in the FeatureList corresponding to
-// feature_list_dense_key[j] must always equal
-// feature_list_dense_shapes[j].NumEntries().
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["feature_list_dense_shapes"] = value
- }
-}
-
-// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors.
-//
-// Arguments:
-// serialized: A scalar containing a binary serialized SequenceExample proto.
-// feature_list_dense_missing_assumed_empty: A vector listing the
-// FeatureList keys which may be missing from the SequenceExample. If the
-// associated FeatureList is missing, it is treated as empty. By default,
-// any FeatureList not listed in this vector must exist in the SequenceExample.
-// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars).
-// The keys expected in the Examples' features associated with context_sparse
-// values.
-// context_dense_keys: A list of Ncontext_dense string Tensors (scalars).
-// The keys expected in the SequenceExamples' context features associated with
-// dense values.
-// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors
-// (scalars). The keys expected in the FeatureLists associated with sparse
-// values.
-// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars).
-// The keys expected in the SequenceExamples' feature_lists associated
-// with lists of dense values.
-// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty).
-// context_dense_defaults[j] provides default values
-// when the SequenceExample's context map lacks context_dense_key[j].
-// If an empty Tensor is provided for context_dense_defaults[j],
-// then the Feature context_dense_keys[j] is required.
-// The input type is inferred from context_dense_defaults[j], even when it's
-// empty. If context_dense_defaults[j] is not empty, its shape must match
-// context_dense_shapes[j].
-// debug_name: A scalar containing the name of the serialized proto.
-// May contain, for example, table key (descriptive) name for the
-// corresponding serialized proto. This is purely useful for debugging
-// purposes, and the presence of values here has no effect on the output.
-// May also be an empty scalar if no name is available.
-func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ParseSingleSequenceExample",
- Input: []tf.Input{
- serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values
-}
-
// QuantizedConv2DAttr is an optional argument to QuantizedConv2D.
type QuantizedConv2DAttr func(optionalAttr)
@@ -18490,34 +18556,6 @@ func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
return scope.AddOperation(opspec)
}
-// Adjust the hue of one or more images.
-//
-// `images` is a tensor of at least 3 dimensions. The last dimension is
-// interpretted as channels, and must be three.
-//
-// The input image is considered in the RGB colorspace. Conceptually, the RGB
-// colors are first mapped into HSV. A delta is then applied all the hue values,
-// and then remapped back to RGB colorspace.
-//
-// Arguments:
-// images: Images to adjust. At least 3-D.
-// delta: A float delta to add to the hue.
-//
-// Returns The hue-adjusted image or images.
-func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AdjustHue",
- Input: []tf.Input{
- images, delta,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam.
type ResourceApplyAdamAttr func(optionalAttr)
@@ -18625,72 +18663,6 @@ func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) {
return op.Output(0)
}
-// Computes gradients for SparseSegmentMean.
-//
-// Returns tensor "output" with same shape as grad, except for dimension 0 whose
-// value is output_dim0.
-//
-// Arguments:
-// grad: gradient propagated to the SparseSegmentMean op.
-// indices: indices passed to the corresponding SparseSegmentMean op.
-// segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
-// output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
-func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseSegmentMeanGrad",
- Input: []tf.Input{
- grad, indices, segment_ids, output_dim0,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// SummaryWriterAttr is an optional argument to SummaryWriter.
-type SummaryWriterAttr func(optionalAttr)
-
-// SummaryWriterSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func SummaryWriterSharedName(value string) SummaryWriterAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// SummaryWriterContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func SummaryWriterContainer(value string) SummaryWriterAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// Returns a handle to be used to access a summary writer.
-//
-// The summary writer is an in-graph resource which can be used by ops to write
-// summaries to event files.
-//
-// Returns the summary writer resource. Scalar handle.
-func SummaryWriter(scope *Scope, optional ...SummaryWriterAttr) (writer tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "SummaryWriter",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResizeBicubicGradAttr is an optional argument to ResizeBicubicGrad.
type ResizeBicubicGradAttr func(optionalAttr)
@@ -20245,6 +20217,8 @@ func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size
}
// Deprecated. Use TensorArrayGradV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayGradV3
func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) {
if scope.Err() != nil {
return
@@ -21538,6 +21512,8 @@ func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr {
}
// Deprecated. Use TensorArrayGatherV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3
func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) {
if scope.Err() != nil {
return
@@ -22262,6 +22238,8 @@ func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (
// Deprecated. Use TensorArrayCloseV3
//
+// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3
+//
// Returns the created operation.
func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
@@ -22381,6 +22359,69 @@ func Abs(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Flushes and closes the summary writer.
+//
+// Also removes it from the resource manager. To reopen, use another
+// CreateSummaryFileWriter op.
+//
+// Arguments:
+// writer: A handle to the summary writer resource.
+//
+// Returns the created operation.
+func CloseSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "CloseSummaryWriter",
+ Input: []tf.Input{
+ writer,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// StackV2Attr is an optional argument to StackV2.
+type StackV2Attr func(optionalAttr)
+
+// StackV2StackName sets the optional stack_name attribute to value.
+//
+// value: Overrides the name used for the temporary stack resource. Default
+// value is the name of the 'Stack' op (which is guaranteed unique).
+// If not specified, defaults to ""
+func StackV2StackName(value string) StackV2Attr {
+ return func(m optionalAttr) {
+ m["stack_name"] = value
+ }
+}
+
+// A stack that produces elements in first-in last-out order.
+//
+// Arguments:
+// max_size: The maximum size of the stack if non-negative. If negative, the stack
+// size is unlimited.
+// elem_type: The type of the elements on the stack.
+//
+// Returns The handle to the stack.
+func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"elem_type": elem_type}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StackV2",
+ Input: []tf.Input{
+ max_size,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// OrderedMapStageAttr is an optional argument to OrderedMapStage.
type OrderedMapStageAttr func(optionalAttr)
@@ -23218,6 +23259,8 @@ func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size
}
// Deprecated. Use TensorArrayGradV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3
func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
if scope.Err() != nil {
return
@@ -23368,6 +23411,8 @@ func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output t
}
// Deprecated. Use TensorArrayScatterV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3
func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
if scope.Err() != nil {
return
@@ -23572,6 +23617,8 @@ func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, o
}
// Deprecated. Use TensorArraySizeV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3
func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) {
if scope.Err() != nil {
return
@@ -25440,6 +25487,246 @@ func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
+// Adjust the hue of one or more images.
+//
+// `images` is a tensor of at least 3 dimensions. The last dimension is
+// interpretted as channels, and must be three.
+//
+// The input image is considered in the RGB colorspace. Conceptually, the RGB
+// colors are first mapped into HSV. A delta is then applied all the hue values,
+// and then remapped back to RGB colorspace.
+//
+// Arguments:
+// images: Images to adjust. At least 3-D.
+// delta: A float delta to add to the hue.
+//
+// Returns The hue-adjusted image or images.
+func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AdjustHue",
+ Input: []tf.Input{
+ images, delta,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad.
+type AvgPool3DGradAttr func(optionalAttr)
+
+// AvgPool3DGradDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format of the input and output data. With the
+// default format "NDHWC", the data is stored in the order of:
+// [batch, in_depth, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCDHW", the data storage order is:
+// [batch, in_channels, in_depth, in_height, in_width].
+// If not specified, defaults to "NDHWC"
+func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Computes gradients of average pooling function.
+//
+// Arguments:
+// orig_input_shape: The original input dimensions.
+// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
+// ksize: 1-D tensor of length 5. The size of the window for each dimension of
+// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+// strides: 1-D tensor of length 5. The stride of the sliding window for each
+// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+// padding: The type of padding algorithm to use.
+//
+// Returns The backprop for input.
+func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AvgPool3DGrad",
+ Input: []tf.Input{
+ orig_input_shape, grad,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample.
+type ParseSingleSequenceExampleAttr func(optionalAttr)
+
+// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value.
+//
+// value: A list of Ncontext_sparse types; the data types of data in
+// each context Feature given in context_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["context_sparse_types"] = value
+ }
+}
+
+// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_dense_types"] = value
+ }
+}
+
+// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value.
+//
+// value: A list of Ncontext_dense shapes; the shapes of data in
+// each context Feature given in context_dense_keys.
+// The number of elements in the Feature corresponding to context_dense_key[j]
+// must always equal context_dense_shapes[j].NumEntries().
+// The shape of context_dense_values[j] will match context_dense_shapes[j].
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["context_dense_shapes"] = value
+ }
+}
+
+// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value.
+//
+// value: A list of Nfeature_list_sparse types; the data types
+// of data in each FeatureList given in feature_list_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_sparse_types"] = value
+ }
+}
+
+// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value.
+//
+// value: A list of Nfeature_list_dense shapes; the shapes of
+// data in each FeatureList given in feature_list_dense_keys.
+// The shape of each Feature in the FeatureList corresponding to
+// feature_list_dense_key[j] must always equal
+// feature_list_dense_shapes[j].NumEntries().
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_dense_shapes"] = value
+ }
+}
+
+// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors.
+//
+// Arguments:
+// serialized: A scalar containing a binary serialized SequenceExample proto.
+// feature_list_dense_missing_assumed_empty: A vector listing the
+// FeatureList keys which may be missing from the SequenceExample. If the
+// associated FeatureList is missing, it is treated as empty. By default,
+// any FeatureList not listed in this vector must exist in the SequenceExample.
+// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars).
+// The keys expected in the Examples' features associated with context_sparse
+// values.
+// context_dense_keys: A list of Ncontext_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' context features associated with
+// dense values.
+// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors
+// (scalars). The keys expected in the FeatureLists associated with sparse
+// values.
+// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' feature_lists associated
+// with lists of dense values.
+// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty).
+// context_dense_defaults[j] provides default values
+// when the SequenceExample's context map lacks context_dense_key[j].
+// If an empty Tensor is provided for context_dense_defaults[j],
+// then the Feature context_dense_keys[j] is required.
+// The input type is inferred from context_dense_defaults[j], even when it's
+// empty. If context_dense_defaults[j] is not empty, its shape must match
+// context_dense_shapes[j].
+// debug_name: A scalar containing the name of the serialized proto.
+// May contain, for example, table key (descriptive) name for the
+// corresponding serialized proto. This is purely useful for debugging
+// purposes, and the presence of values here has no effect on the output.
+// May also be an empty scalar if no name is available.
+func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ParseSingleSequenceExample",
+ Input: []tf.Input{
+ serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values
+}
+
// DecodeWavAttr is an optional argument to DecodeWav.
type DecodeWavAttr func(optionalAttr)
@@ -28068,272 +28355,3 @@ func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1), op.Output(2)
}
-
-// QuantizeV2Attr is an optional argument to QuantizeV2.
-type QuantizeV2Attr func(optionalAttr)
-
-// QuantizeV2Mode sets the optional mode attribute to value.
-// If not specified, defaults to "MIN_COMBINED"
-func QuantizeV2Mode(value string) QuantizeV2Attr {
- return func(m optionalAttr) {
- m["mode"] = value
- }
-}
-
-// QuantizeV2RoundMode sets the optional round_mode attribute to value.
-// If not specified, defaults to "HALF_AWAY_FROM_ZERO"
-func QuantizeV2RoundMode(value string) QuantizeV2Attr {
- return func(m optionalAttr) {
- m["round_mode"] = value
- }
-}
-
-// Quantize the 'input' tensor of type float to 'output' tensor of type 'T'.
-//
-// [min_range, max_range] are scalar floats that specify the range for
-// the 'input' data. The 'mode' attribute controls exactly which calculations are
-// used to convert the float values to their quantized equivalents. The
-// 'round_mode' attribute controls which rounding tie-breaking algorithm is used
-// when rounding float values to their quantized equivalents.
-//
-// In 'MIN_COMBINED' mode, each value of the tensor will undergo the following:
-//
-// ```
-// out[i] = (in[i] - min_range) * range(T) / (max_range - min_range)
-// if T == qint8, out[i] -= (range(T) + 1) / 2.0
-// ```
-// here `range(T) = numeric_limits<T>::max() - numeric_limits<T>::min()`
-//
-// *MIN_COMBINED Mode Example*
-//
-// Assume the input is type float and has a possible range of [0.0, 6.0] and the
-// output type is quint8 ([0, 255]). The min_range and max_range values should be
-// specified as 0.0 and 6.0. Quantizing from float to quint8 will multiply each
-// value of the input by 255/6 and cast to quint8.
-//
-// If the output type was qint8 ([-128, 127]), the operation will additionally
-// subtract each value by 128 prior to casting, so that the range of values aligns
-// with the range of qint8.
-//
-// If the mode is 'MIN_FIRST', then this approach is used:
-//
-// ```
-// num_discrete_values = 1 << (# of bits in T)
-// range_adjust = num_discrete_values / (num_discrete_values - 1)
-// range = (range_max - range_min) * range_adjust
-// range_scale = num_discrete_values / range
-// quantized = round(input * range_scale) - round(range_min * range_scale) +
-// numeric_limits<T>::min()
-// quantized = max(quantized, numeric_limits<T>::min())
-// quantized = min(quantized, numeric_limits<T>::max())
-// ```
-//
-// The biggest difference between this and MIN_COMBINED is that the minimum range
-// is rounded first, before it's subtracted from the rounded value. With
-// MIN_COMBINED, a small bias is introduced where repeated iterations of quantizing
-// and dequantizing will introduce a larger and larger error.
-//
-// *SCALED mode Example*
-//
-// `SCALED` mode matches the quantization approach used in
-// `QuantizeAndDequantize{V2|V3}`.
-//
-// If the mode is `SCALED`, we do not use the full range of the output type,
-// choosing to elide the lowest possible value for symmetry (e.g., output range is
-// -127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to
-// 0.
-//
-// We first find the range of values in our tensor. The
-// range we use is always centered on 0, so we find m such that
-// ```c++
-// m = max(abs(input_min), abs(input_max))
-// ```
-//
-// Our input tensor range is then `[-m, m]`.
-//
-// Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`.
-// If T is signed, this is
-// ```
-// num_bits = sizeof(T) * 8
-// [min_fixed, max_fixed] =
-// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1]
-// ```
-//
-// Otherwise, if T is unsigned, the fixed-point range is
-// ```
-// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]
-// ```
-//
-// From this we compute our scaling factor, s:
-// ```c++
-// s = (max_fixed - min_fixed) / (2 * m)
-// ```
-//
-// Now we can quantize the elements of our tensor:
-// ```c++
-// result = round(input * s)
-// ```
-//
-// One thing to watch out for is that the operator may choose to adjust the
-// requested minimum and maximum values slightly during the quantization process,
-// so you should always use the output ports as the range for further calculations.
-// For example, if the requested minimum and maximum values are close to equal,
-// they will be separated by a small epsilon value to prevent ill-formed quantized
-// buffers from being created. Otherwise, you can end up with buffers where all the
-// quantized values map to the same float value, which causes problems for
-// operations that have to perform further calculations on them.
-//
-// Arguments:
-//
-// min_range: The minimum scalar value possibly produced for the input.
-// max_range: The maximum scalar value possibly produced for the input.
-//
-//
-// Returns The quantized data produced from the float input.The actual minimum scalar value used for the output.The actual maximum scalar value used for the output.
-func QuantizeV2(scope *Scope, input tf.Output, min_range tf.Output, max_range tf.Output, T tf.DataType, optional ...QuantizeV2Attr) (output tf.Output, output_min tf.Output, output_max tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"T": T}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "QuantizeV2",
- Input: []tf.Input{
- input, min_range, max_range,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// Flushes the writer's unwritten events.
-//
-// Arguments:
-// writer: A handle to the summary writer resource.
-//
-// Returns the created operation.
-func FlushSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "FlushSummaryWriter",
- Input: []tf.Input{
- writer,
- },
- }
- return scope.AddOperation(opspec)
-}
-
-// StackV2Attr is an optional argument to StackV2.
-type StackV2Attr func(optionalAttr)
-
-// StackV2StackName sets the optional stack_name attribute to value.
-//
-// value: Overrides the name used for the temporary stack resource. Default
-// value is the name of the 'Stack' op (which is guaranteed unique).
-// If not specified, defaults to ""
-func StackV2StackName(value string) StackV2Attr {
- return func(m optionalAttr) {
- m["stack_name"] = value
- }
-}
-
-// A stack that produces elements in first-in last-out order.
-//
-// Arguments:
-// max_size: The maximum size of the stack if non-negative. If negative, the stack
-// size is unlimited.
-// elem_type: The type of the elements on the stack.
-//
-// Returns The handle to the stack.
-func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"elem_type": elem_type}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StackV2",
- Input: []tf.Input{
- max_size,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Flushes and closes the summary writer.
-//
-// Also removes it from the resource manager. To reopen, use another
-// CreateSummaryFileWriter op.
-//
-// Arguments:
-// writer: A handle to the summary writer resource.
-//
-// Returns the created operation.
-func CloseSummaryWriter(scope *Scope, writer tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "CloseSummaryWriter",
- Input: []tf.Input{
- writer,
- },
- }
- return scope.AddOperation(opspec)
-}
-
-// Outputs a `Summary` protocol buffer with a tensor.
-//
-// Arguments:
-// writer: A handle to a summary writer.
-// step: The step to write the summary for.
-// tensor: A tensor to serialize.
-// tag: The summary's tag.
-// summary_metadata: Serialized SummaryMetadata protocol buffer containing
-// plugin-related metadata for this summary.
-//
-// Returns the created operation.
-func WriteSummary(scope *Scope, writer tf.Output, step tf.Output, tensor tf.Output, tag tf.Output, summary_metadata tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "WriteSummary",
- Input: []tf.Input{
- writer, step, tensor, tag, summary_metadata,
- },
- }
- return scope.AddOperation(opspec)
-}
-
-// Outputs a `tf.Event` protocol buffer.
-//
-// When CreateSummaryDbWriter is being used, this op can be useful for
-// importing data from event logs.
-//
-// Arguments:
-// writer: A handle to a summary writer.
-// event: A string containing a binary-encoded tf.Event proto.
-//
-// Returns the created operation.
-func ImportEvent(scope *Scope, writer tf.Output, event tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ImportEvent",
- Input: []tf.Input{
- writer, event,
- },
- }
- return scope.AddOperation(opspec)
-}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index dbb29d9878..01b3e92d2d 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1388,6 +1388,13 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "batch_ops_gen",
+ visibility = [
+ "//tensorflow:__subpackages__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "math_ops_gen",
visibility = [
"//learning/brain/google/python/ops:__pkg__",
@@ -1951,6 +1958,7 @@ py_library(
srcs = ["ops/list_ops.py"],
srcs_version = "PY2AND3",
deps = [
+ ":array_ops",
":list_ops_gen",
],
)
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 1481a4d035..e6f94396b8 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""A client interface for TensorFlow."""
from __future__ import absolute_import
@@ -71,8 +70,9 @@ def _get_indexed_slices_value_from_fetches(fetched_vals):
def _get_feeds_for_indexed_slices(feed, feed_val):
- return list(zip([feed.values, feed.indices] if feed.dense_shape is None else
- [feed.values, feed.indices, feed.dense_shape], feed_val))
+ return list(
+ zip([feed.values, feed.indices] if feed.dense_shape is None else
+ [feed.values, feed.indices, feed.dense_shape], feed_val))
# List of extensions supported to convert run arguments into actual fetches and
@@ -124,6 +124,7 @@ _REGISTERED_EXPANSIONS = [
lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
lambda feed, feed_val: [(feed, feed_val)],
lambda feed: [feed])]
+
# pylint: enable=g-long-lambda
@@ -132,8 +133,11 @@ def _convert_to_numpy_obj(numpy_dtype, obj):
return numpy_dtype(obj) if numpy_dtype is not object else str(obj)
-def register_session_run_conversion_functions(tensor_type, fetch_function,
- feed_function=None, feed_function_for_partial_run=None):
+def register_session_run_conversion_functions(
+ tensor_type,
+ fetch_function,
+ feed_function=None,
+ feed_function_for_partial_run=None):
"""Register fetch and feed conversion functions for `tf.Session.run()`.
This function registers a triple of conversion functions for fetching and/or
@@ -174,11 +178,11 @@ def register_session_run_conversion_functions(tensor_type, fetch_function,
"""
for conversion_function in _REGISTERED_EXPANSIONS:
if issubclass(conversion_function[0], tensor_type):
- raise ValueError(
- '%s has already been registered so ignore it.', tensor_type)
+ raise ValueError('%s has already been registered so ignore it.',
+ tensor_type)
return
- _REGISTERED_EXPANSIONS.insert(0,
- (tensor_type, fetch_function, feed_function, feed_function_for_partial_run))
+ _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
+ feed_function_for_partial_run))
class _FetchMapper(object):
@@ -233,8 +237,8 @@ class _FetchMapper(object):
An instance of a subclass of `_FetchMapper` that handles the shape.
"""
if fetch is None:
- raise TypeError('Fetch argument %r has invalid type %r' %
- (fetch, type(fetch)))
+ raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
+ type(fetch)))
elif isinstance(fetch, (list, tuple)):
# NOTE(touts): This is also the code path for namedtuples.
return _ListFetchMapper(fetch)
@@ -247,8 +251,8 @@ class _FetchMapper(object):
fetches, contraction_fn = fetch_fn(fetch)
return _ElementFetchMapper(fetches, contraction_fn)
# Did not find anything.
- raise TypeError('Fetch argument %r has invalid type %r' %
- (fetch, type(fetch)))
+ raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
+ type(fetch)))
class _ElementFetchMapper(_FetchMapper):
@@ -277,8 +281,8 @@ class _ElementFetchMapper(_FetchMapper):
fetch, allow_tensor=True, allow_operation=True))
except TypeError as e:
raise TypeError('Fetch argument %r has invalid type %r, '
- 'must be a string or Tensor. (%s)'
- % (fetch, type(fetch), str(e)))
+ 'must be a string or Tensor. (%s)' %
+ (fetch, type(fetch), str(e)))
except ValueError as e:
raise ValueError('Fetch argument %r cannot be interpreted as a '
'Tensor. (%s)' % (fetch, str(e)))
@@ -376,8 +380,9 @@ class _DictFetchMapper(_FetchMapper):
"""
self._fetch_type = type(fetches)
self._keys = fetches.keys()
- self._mappers = [_FetchMapper.for_fetch(fetch)
- for fetch in fetches.values()]
+ self._mappers = [
+ _FetchMapper.for_fetch(fetch) for fetch in fetches.values()
+ ]
self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
def unique_fetches(self):
@@ -401,6 +406,7 @@ class _FetchHandler(object):
result structure matching the user-provided structure for fetches, but
containing the corresponding results.
"""
+
# TODO(touts): Make this class also take care of destructuring the feed
# dict instead of doing it in the callers.
@@ -551,8 +557,11 @@ class _DeviceAttributes(object):
return self._memory_limit_bytes
def __repr__(self):
- return '_DeviceAttributes(%s, %s, %d)' % (self.name, self.device_type,
- self.memory_limit_bytes,)
+ return '_DeviceAttributes(%s, %s, %d)' % (
+ self.name,
+ self.device_type,
+ self.memory_limit_bytes,
+ )
class BaseSession(SessionInterface):
@@ -601,8 +610,8 @@ class BaseSession(SessionInterface):
if config is not None:
if not isinstance(config, config_pb2.ConfigProto):
- raise TypeError('config must be a tf.ConfigProto, but got %s'
- % type(config))
+ raise TypeError(
+ 'config must be a tf.ConfigProto, but got %s' % type(config))
self._config = config
self._add_shapes = config.graph_options.infer_shapes
else:
@@ -976,8 +985,8 @@ class BaseSession(SessionInterface):
for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS:
if isinstance(feed, tensor_type):
return feed_fn(feed)
- raise TypeError('Feed argument %r has invalid type %r'
- % (feed, type(feed)))
+ raise TypeError('Feed argument %r has invalid type %r' % (feed,
+ type(feed)))
# Check session.
if self._closed:
@@ -998,8 +1007,8 @@ class BaseSession(SessionInterface):
for feed in feeds:
for subfeed in _feed_fn(feed):
try:
- subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
- allow_operation=False)
+ subfeed_t = self.graph.as_graph_element(
+ subfeed, allow_tensor=True, allow_operation=False)
if self._created_with_new_api:
# pylint: disable=protected-access
feed_list.append(subfeed_t._as_tf_output())
@@ -1007,8 +1016,7 @@ class BaseSession(SessionInterface):
else:
feed_list.append(compat.as_bytes(subfeed_t.name))
except Exception as e:
- e.message = ('Cannot interpret feed_list key as Tensor: '
- + e.message)
+ e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message)
e.args = (e.message,)
raise e
@@ -1041,12 +1049,13 @@ class BaseSession(SessionInterface):
def _run(self, handle, fetches, feed_dict, options, run_metadata):
"""Perform either run or partial_run, depending the presence of `handle`."""
+
def _feed_fn(feed, feed_val):
for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
if isinstance(feed, tensor_type):
return feed_fn(feed, feed_val)
- raise TypeError('Feed argument %r has invalid type %r'
- % (feed, type(feed)))
+ raise TypeError('Feed argument %r has invalid type %r' % (feed,
+ type(feed)))
# Check session.
if self._closed:
@@ -1066,11 +1075,11 @@ class BaseSession(SessionInterface):
for feed, feed_val in feed_dict.items():
for subfeed, subfeed_val in _feed_fn(feed, feed_val):
try:
- subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
- allow_operation=False)
+ subfeed_t = self.graph.as_graph_element(
+ subfeed, allow_tensor=True, allow_operation=False)
except Exception as e:
- raise TypeError('Cannot interpret feed_dict key as Tensor: '
- + e.args[0])
+ raise TypeError(
+ 'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
if isinstance(subfeed_val, ops.Tensor):
raise TypeError('The value of a feed cannot be a tf.Tensor object. '
@@ -1081,10 +1090,9 @@ class BaseSession(SessionInterface):
if isinstance(subfeed_val, int) and _convert_to_numpy_obj(
subfeed_dtype, subfeed_val) != subfeed_val:
raise TypeError(
- 'Type of feed value ' + str(subfeed_val) + ' with type ' +
- str(type(subfeed_val)) +
- ' is not compatible with Tensor type ' +
- str(subfeed_dtype) +
+ 'Type of feed value ' + str(subfeed_val) + ' with type ' + str(
+ type(subfeed_val)) +
+ ' is not compatible with Tensor type ' + str(subfeed_dtype) +
'. Try explicitly setting the type of the feed tensor'
' to a larger type (e.g. int64).')
@@ -1098,10 +1106,10 @@ class BaseSession(SessionInterface):
if (not is_tensor_handle_feed and
not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
- raise ValueError(
- 'Cannot feed value of shape %r for Tensor %r, '
- 'which has shape %r'
- % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
+ raise ValueError('Cannot feed value of shape %r for Tensor %r, '
+ 'which has shape %r' %
+ (np_val.shape, subfeed_t.name,
+ str(subfeed_t.get_shape())))
if not self.graph.is_feedable(subfeed_t):
raise ValueError('Tensor %s may not be fed.' % subfeed_t)
@@ -1130,10 +1138,7 @@ class BaseSession(SessionInterface):
results = []
return fetch_handler.build_results(self, results)
- def make_callable(self,
- fetches,
- feed_list=None,
- accept_options=False):
+ def make_callable(self, fetches, feed_list=None, accept_options=False):
"""Returns a Python callable that runs a particular step.
The returned callable will take `len(feed_list)` arguments whose types
@@ -1176,9 +1181,12 @@ class BaseSession(SessionInterface):
# `Session._run()` so that we can convert the feeds to a list of
# strings here.
def _generic_run(*feed_args, **kwargs):
- feed_dict = {feed: feed_val
- for feed, feed_val in zip(feed_list, feed_args)}
+ feed_dict = {
+ feed: feed_val
+ for feed, feed_val in zip(feed_list, feed_args)
+ }
return self.run(fetches, feed_dict=feed_dict, **kwargs)
+
return _generic_run
# Ensure any changes to the graph are reflected in the runtime.
@@ -1198,12 +1206,11 @@ class BaseSession(SessionInterface):
fetch_list = _name_list(fetch_handler.fetches())
target_list = _name_list(fetch_handler.targets())
- def _callable_template_with_options_and_metadata(
- fetch_list,
- target_list,
- fetch_handler,
- options=None,
- run_metadata=None):
+ def _callable_template_with_options_and_metadata(fetch_list,
+ target_list,
+ fetch_handler,
+ options=None,
+ run_metadata=None):
"""Template callable that accepts RunOptions and RunMetadata."""
options_ptr = tf_session.TF_NewBufferFromString(
compat.as_bytes(options.SerializeToString())) if options else None
@@ -1215,9 +1222,9 @@ class BaseSession(SessionInterface):
self._session, options_ptr, {}, fetch_list, target_list,
run_metadata_ptr, status)
else:
- results = tf_session.TF_Run(
- self._session, options_ptr, {}, fetch_list, target_list, status,
- run_metadata_ptr)
+ results = tf_session.TF_Run(self._session, options_ptr, {},
+ fetch_list, target_list, status,
+ run_metadata_ptr)
if fetch_handler:
results = fetch_handler.build_results(self, results)
else:
@@ -1233,37 +1240,40 @@ class BaseSession(SessionInterface):
return results
if accept_options:
- return functools.partial(
- _callable_template_with_options_and_metadata, fetch_list,
- target_list, fetch_handler)
+ return functools.partial(_callable_template_with_options_and_metadata,
+ fetch_list, target_list, fetch_handler)
elif isinstance(fetches, ops.Operation):
# Special case for fetching a single operation, because the
# function will have no return value.
assert not fetch_list
assert len(target_list) == 1
+
def _single_operation_run():
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
- tf_session.TF_SessionRun_wrapper(
- self._session, None, {}, [], target_list, None, status)
+ tf_session.TF_SessionRun_wrapper(self._session, None, {}, [],
+ target_list, None, status)
else:
- tf_session.TF_Run(
- self._session, None, {}, [], target_list, status, None)
+ tf_session.TF_Run(self._session, None, {}, [], target_list, status,
+ None)
+
return _single_operation_run
elif isinstance(fetches, ops.Tensor):
# Special case for fetching a single tensor, because the
# function can return the result of `TF_Run()` directly.
assert len(fetch_list) == 1
assert not target_list
+
def _single_tensor_run():
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
results = tf_session.TF_SessionRun_wrapper(
self._session, None, {}, fetch_list, [], None, status)
else:
- results = tf_session.TF_Run(
- self._session, None, {}, fetch_list, [], status, None)
+ results = tf_session.TF_Run(self._session, None, {}, fetch_list, [],
+ status, None)
return results[0]
+
return _single_tensor_run
else:
# In all other cases, we must use `fetch_handler` to build the
@@ -1274,16 +1284,17 @@ class BaseSession(SessionInterface):
results = tf_session.TF_SessionRun_wrapper(
self._session, None, {}, fetch_list, target_list, None, status)
else:
- results = tf_session.TF_Run(
- self._session, None, {}, fetch_list, target_list, status, None)
+ results = tf_session.TF_Run(self._session, None, {}, fetch_list,
+ target_list, status, None)
return fetch_handler.build_results(self, results)
+
return _fetch_handler_run
# Captures the name of a node in an error status.
_NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
- def _do_run(self, handle, target_list, fetch_list, feed_dict,
- options, run_metadata):
+ def _do_run(self, handle, target_list, fetch_list, feed_dict, options,
+ run_metadata):
"""Runs a step based on the given fetches and feeds.
Args:
@@ -1320,13 +1331,12 @@ class BaseSession(SessionInterface):
self._extend_graph()
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
- return tf_session.TF_SessionRun_wrapper(
- session, options, feed_dict, fetch_list, target_list,
- run_metadata, status)
+ return tf_session.TF_SessionRun_wrapper(session, options, feed_dict,
+ fetch_list, target_list,
+ run_metadata, status)
else:
- return tf_session.TF_Run(session, options,
- feed_dict, fetch_list, target_list,
- status, run_metadata)
+ return tf_session.TF_Run(session, options, feed_dict, fetch_list,
+ target_list, status, run_metadata)
def _prun_fn(session, handle, feed_dict, fetch_list):
if target_list:
@@ -1365,20 +1375,20 @@ class BaseSession(SessionInterface):
def _extend_graph(self):
# Nothing to do if we're using the new session interface
# TODO(skyewm): remove this function altogether eventually
- if self._created_with_new_api: return
+ if self._created_with_new_api:
+ return
# Ensure any changes to the graph are reflected in the runtime.
with self._extend_lock:
if self._graph.version > self._current_version:
# pylint: disable=protected-access
graph_def, self._current_version = self._graph._as_graph_def(
- from_version=self._current_version,
- add_shapes=self._add_shapes)
+ from_version=self._current_version, add_shapes=self._add_shapes)
# pylint: enable=protected-access
with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_ExtendGraph(
- self._session, graph_def.SerializeToString(), status)
+ tf_session.TF_ExtendGraph(self._session,
+ graph_def.SerializeToString(), status)
self._opened = True
# The threshold to run garbage collection to delete dead tensors.
@@ -1398,9 +1408,8 @@ class BaseSession(SessionInterface):
feeds = {}
fetches = []
for deleter_key, tensor_handle in enumerate(tensors_to_delete):
- holder, deleter = session_ops._get_handle_deleter(self.graph,
- deleter_key,
- tensor_handle)
+ holder, deleter = session_ops._get_handle_deleter(
+ self.graph, deleter_key, tensor_handle)
feeds[holder] = tensor_handle
fetches.append(deleter)
self.run(fetches, feed_dict=feeds)
@@ -1471,7 +1480,8 @@ class Session(BaseSession):
sess.run(...)
```
- The [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
+ The
+ [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
protocol buffer exposes various configuration options for a
session. For example, to create a session that uses soft constraints
for device placement, and log the resulting placement decisions,
@@ -1502,7 +1512,8 @@ class Session(BaseSession):
@{$distributed$Distributed TensorFlow}
for more examples.
graph: (Optional.) The `Graph` to be launched (described above).
- config: (Optional.) A [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
+ config: (Optional.) A
+ [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
protocol buffer with configuration options for the session.
"""
@@ -1526,8 +1537,8 @@ class Session(BaseSession):
def __exit__(self, exec_type, exec_value, exec_tb):
if exec_type is errors.OpError:
logging.error('Session closing due to OpError: %s', (exec_value,))
- self._default_session_context_manager.__exit__(
- exec_type, exec_value, exec_tb)
+ self._default_session_context_manager.__exit__(exec_type, exec_value,
+ exec_tb)
self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)
self._default_session_context_manager = None
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index c579fba339..768a5db88a 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Tests for tensorflow.python.client.session.Session."""
from __future__ import absolute_import
from __future__ import division
@@ -57,7 +56,6 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
-
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
@@ -95,14 +93,18 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(arr, copy_val)
# Test without feed.
copy_val = copy.eval()
- self.assertAllEqual(np.asarray([[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]],
- dtype=np.float32), copy_val)
+ self.assertAllEqual(
+ np.asarray(
+ [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32),
+ copy_val)
def testManyCPUs(self):
# TODO(keveman): Implement ListDevices and test for the number of
# devices returned by ListDevices.
with session.Session(
- config=config_pb2.ConfigProto(device_count={'CPU': 2})):
+ config=config_pb2.ConfigProto(device_count={
+ 'CPU': 2
+ })):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
@@ -161,20 +163,23 @@ class SessionTest(test_util.TensorFlowTestCase):
def exc_predicate(e):
return (e.op is None and e.node_def is None and
e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+
with self.assertRaisesOpError(exc_predicate):
# Run with a bogus handle.
s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
def testOpConstructionErrorPayload(self):
- if ops._USE_C_API: return # No shape registration for 'ConstructionFails'
+ if ops._USE_C_API:
+ return # No shape registration for 'ConstructionFails'
with session.Session():
failing_op = ops.get_default_graph().create_op(
'ConstructionFails', [], [], name='f')
def exc_predicate(e):
- return (e.op == failing_op
- and e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+ return (e.op == failing_op and
+ e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+
with self.assertRaisesOpError(exc_predicate):
failing_op.run()
@@ -191,9 +196,9 @@ class SessionTest(test_util.TensorFlowTestCase):
# pylint: enable=protected-access
def exc_predicate(e):
- return (e.op == c.op
- and e.op._original_op == b.op
- and e.op._original_op._original_op == a.op)
+ return (e.op == c.op and e.op._original_op == b.op and
+ e.op._original_op._original_op == a.op)
+
with self.assertRaisesOpError(exc_predicate):
c.eval()
@@ -341,8 +346,12 @@ class SessionTest(test_util.TensorFlowTestCase):
b = control_flow_ops.no_op() # An op, not a tensor.
c = constant_op.constant(c_val)
# List of lists, tuples, namedtuple, and dict
- res = sess.run([[a, b, c], (a, b, c), ABC(a=a, b=b, c=c),
- {'a': a.name, 'c': c, 'b': b}])
+ res = sess.run([[a, b, c], (a, b, c),
+ ABC(a=a, b=b, c=c), {
+ 'a': a.name,
+ 'c': c,
+ 'b': b
+ }])
self.assertTrue(isinstance(res, list))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res[0], list))
@@ -365,8 +374,11 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(b_val, res[3]['b'])
self.assertEqual(c_val, res[3]['c'])
# Tuple of lists, tuples, namedtuple, and dict
- res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c),
- {'a': a, 'c': c, 'b': b}))
+ res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), {
+ 'a': a,
+ 'c': c,
+ 'b': b
+ }))
self.assertTrue(isinstance(res, tuple))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res[0], list))
@@ -389,10 +401,16 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(b_val, res[3]['b'])
self.assertEqual(c_val, res[3]['c'])
# Namedtuple of lists, tuples, namedtuples, and dict
- res = sess.run(DEFG(d=[a, b, c],
- e=(a, b, c),
- f=ABC(a=a.name, b=b, c=c),
- g={'a': a, 'c': c, 'b': b}))
+ res = sess.run(
+ DEFG(
+ d=[a, b, c],
+ e=(a, b, c),
+ f=ABC(a=a.name, b=b, c=c),
+ g={
+ 'a': a,
+ 'c': c,
+ 'b': b
+ }))
self.assertTrue(isinstance(res, DEFG))
self.assertTrue(isinstance(res.d, list))
self.assertEqual(3, len(res.d))
@@ -414,10 +432,16 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(b_val, res.g['b'])
self.assertEqual(c_val, res.g['c'])
# Dict of lists, tuples, namedtuples, and dict
- res = sess.run({'d': [a, b, c],
- 'e': (a, b, c),
- 'f': ABC(a=a, b=b, c=c),
- 'g': {'a': a.name, 'c': c, 'b': b}})
+ res = sess.run({
+ 'd': [a, b, c],
+ 'e': (a, b, c),
+ 'f': ABC(a=a, b=b, c=c),
+ 'g': {
+ 'a': a.name,
+ 'c': c,
+ 'b': b
+ }
+ })
self.assertTrue(isinstance(res, dict))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res['d'], list))
@@ -516,8 +540,7 @@ class SessionTest(test_util.TensorFlowTestCase):
values = np.array([1.0, 2.0]).astype(np.float32)
shape = np.array([7, 9, 2]).astype(np.int64)
sp = sparse_tensor.SparseTensor(
- constant_op.constant(indices),
- constant_op.constant(values),
+ constant_op.constant(indices), constant_op.constant(values),
constant_op.constant(shape))
# Single fetch, use as tuple
sp_out = s.run(sp)
@@ -587,14 +610,17 @@ class SessionTest(test_util.TensorFlowTestCase):
sp = sparse_tensor.SparseTensor(
array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
array_ops.placeholder(dtype=np.float32, shape=(2,)),
- array_ops.placeholder(dtype=np.int64, shape=(3,)),)
+ array_ops.placeholder(dtype=np.int64, shape=(3,)),
+ )
sp_indices = array_ops.identity(sp.indices)
sp_values = array_ops.identity(sp.values)
sp_shape = array_ops.identity(sp.dense_shape)
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: (indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
@@ -605,20 +631,23 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(sp_out.dense_shape, shape)
# Feed with SparseTensorValue
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape],
- {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue, fetch SparseTensorValue
- sp2_out = s.run(
- sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ sp2_out = s.run(sp2, {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(sp2_out.indices, indices)
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.dense_shape, shape)
# Feed SparseTensorValue and fetch sp directly.
- sp_out = s.run(
- sp, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ sp_out = s.run(sp, {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(sp_out.indices, indices)
self.assertAllEqual(sp_out.values, values)
self.assertAllEqual(sp_out.dense_shape, shape)
@@ -635,20 +664,24 @@ class SessionTest(test_util.TensorFlowTestCase):
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: (indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape],
- {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue, fetch SparseTensorValue
- sp2_out = s.run(
- sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ sp2_out = s.run(sp2, {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(sp2_out.indices, indices)
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.dense_shape, shape)
@@ -666,20 +699,24 @@ class SessionTest(test_util.TensorFlowTestCase):
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: (indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape],
- {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue, fetch SparseTensorValue
- sp2_out = s.run(
- sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ sp2_out = s.run(sp2, {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(sp2_out.indices, indices)
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.dense_shape, shape)
@@ -689,9 +726,8 @@ class SessionTest(test_util.TensorFlowTestCase):
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
values = np.array([1.0, 2.0]).astype(np.float32)
shape = np.array([7, 9, 2]).astype(np.int64)
- sp = array_ops.sparse_placeholder(dtype=np.float32,
- shape=shape,
- name='placeholder1')
+ sp = array_ops.sparse_placeholder(
+ dtype=np.float32, shape=shape, name='placeholder1')
self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape)
sp_indices = array_ops.identity(sp.indices)
@@ -699,7 +735,9 @@ class SessionTest(test_util.TensorFlowTestCase):
sp_shape = array_ops.identity(sp.dense_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape], {sp: (indices, values)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: (indices, values)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
@@ -745,33 +783,34 @@ class SessionTest(test_util.TensorFlowTestCase):
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
dense_shape = np.array([7, 9, 2]).astype(np.int64)
ind = ops.IndexedSlices(
- array_ops.placeholder(dtype=np.float32,
- shape=(2,)),
- array_ops.placeholder(dtype=np.int64,
- shape=(2, 3)),
- array_ops.placeholder(dtype=np.int64,
- shape=(3,)),)
+ array_ops.placeholder(dtype=np.float32, shape=(2,)),
+ array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
+ array_ops.placeholder(dtype=np.int64, shape=(3,)),
+ )
ind_values = array_ops.identity(ind.values)
ind_indices = array_ops.identity(ind.indices)
ind_dense_shape = array_ops.identity(ind.dense_shape)
ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape)
# Feed with tuple
values_out, indices_out, dense_shape_out = s.run(
- [ind_values, ind_indices, ind_dense_shape],
- {ind: (values, indices, dense_shape)})
+ [ind_values, ind_indices, ind_dense_shape], {
+ ind: (values, indices, dense_shape)
+ })
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# Feed with IndexedSlicesValue
values_out, indices_out, dense_shape_out = s.run(
- [ind_values, ind_indices, ind_dense_shape],
- {ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
+ [ind_values, ind_indices, ind_dense_shape], {
+ ind: ops.IndexedSlicesValue(values, indices, dense_shape)
+ })
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# Feed with IndexedSlicesValue, fetch IndexedSlicesValue
- ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices,
- dense_shape)})
+ ind2_out = s.run(ind2, {
+ ind: ops.IndexedSlicesValue(values, indices, dense_shape)
+ })
self.assertAllEqual(ind2_out.values, values)
self.assertAllEqual(ind2_out.indices, indices)
self.assertAllEqual(ind2_out.dense_shape, dense_shape)
@@ -816,28 +855,27 @@ class SessionTest(test_util.TensorFlowTestCase):
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
dense_shape = None
ind = ops.IndexedSlices(
- array_ops.placeholder(dtype=np.float32,
- shape=(2,)),
- array_ops.placeholder(dtype=np.int64,
- shape=(2, 3)),
- None)
+ array_ops.placeholder(dtype=np.float32, shape=(2,)),
+ array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None)
ind_values = array_ops.identity(ind.values)
ind_indices = array_ops.identity(ind.indices)
ind2 = ops.IndexedSlices(ind_values, ind_indices)
# Feed with tuple
- values_out, indices_out = s.run(
- [ind_values, ind_indices], {ind: (values, indices)})
+ values_out, indices_out = s.run([ind_values, ind_indices], {
+ ind: (values, indices)
+ })
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
# Feed with IndexedSlicesValue
- values_out, indices_out = s.run(
- [ind_values, ind_indices],
- {ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
+ values_out, indices_out = s.run([ind_values, ind_indices], {
+ ind: ops.IndexedSlicesValue(values, indices, dense_shape)
+ })
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
# Feed with IndexedSlicesValue, fetch IndexedSlicesValue
- ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices,
- dense_shape)})
+ ind2_out = s.run(ind2, {
+ ind: ops.IndexedSlicesValue(values, indices, dense_shape)
+ })
self.assertAllEqual(ind2_out.values, values)
self.assertAllEqual(ind2_out.indices, indices)
self.assertAllEqual(ind2_out.dense_shape, dense_shape)
@@ -986,8 +1024,9 @@ class SessionTest(test_util.TensorFlowTestCase):
constructed_events = [threading.Event() for _ in range(10)]
continue_event = threading.Event()
for i, constructed_event in enumerate(constructed_events):
- t = self.checkedThread(target=self._testDefaultGraphInThread,
- args=(constructed_event, continue_event, i))
+ t = self.checkedThread(
+ target=self._testDefaultGraphInThread,
+ args=(constructed_event, continue_event, i))
threads.append(t)
for t in threads:
t.start()
@@ -1006,6 +1045,7 @@ class SessionTest(test_util.TensorFlowTestCase):
ev.wait()
val = c.eval(session=sess)
self.assertEqual(val, 5.0)
+
threads = [self.checkedThread(target=run_step) for _ in range(100)]
for t in threads:
t.start()
@@ -1038,11 +1078,10 @@ class SessionTest(test_util.TensorFlowTestCase):
def testGraphDef(self):
with session.Session() as sess:
- self.assertProtoEquals(
- 'versions { producer: %d min_consumer: %d }' % (
- versions.GRAPH_DEF_VERSION,
- versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
- sess.graph_def)
+ self.assertProtoEquals('versions { producer: %d min_consumer: %d }' %
+ (versions.GRAPH_DEF_VERSION,
+ versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
+ sess.graph_def)
c = constant_op.constant(5.0, name='c')
self.assertEquals(len(sess.graph_def.node), 1)
d = constant_op.constant(6.0, name='d')
@@ -1072,6 +1111,7 @@ class SessionTest(test_util.TensorFlowTestCase):
lambda e: 'Attempted to use a closed Session.' in str(e)):
while True:
sess.run(c)
+
t = threading.Thread(target=update_thread)
t.start()
time.sleep(0.1)
@@ -1177,17 +1217,11 @@ class SessionTest(test_util.TensorFlowTestCase):
def testFeedAndFetch(self):
with session.Session() as sess:
- for dtype in [dtypes.float16,
- dtypes.float32,
- dtypes.float64,
- dtypes.int32,
- dtypes.uint8,
- dtypes.int16,
- dtypes.int8,
- dtypes.int64,
- dtypes.bool,
- dtypes.complex64,
- dtypes.complex128]:
+ for dtype in [
+ dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool,
+ dtypes.complex64, dtypes.complex128
+ ]:
for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
np_dtype = dtype.as_numpy_dtype
@@ -1206,13 +1240,19 @@ class SessionTest(test_util.TensorFlowTestCase):
np_array = np_array.astype(np_dtype)
self.assertAllEqual(np_array,
- sess.run(out_t, feed_dict={feed_t: np_array}))
+ sess.run(out_t, feed_dict={
+ feed_t: np_array
+ }))
# Check that we can also get the feed back.
self.assertAllEqual(np_array,
- sess.run(feed_t, feed_dict={feed_t: np_array}))
+ sess.run(feed_t, feed_dict={
+ feed_t: np_array
+ }))
# Also check that we can get both back.
- out_v, feed_v = sess.run([out_t, feed_t],
- feed_dict={feed_t: np_array})
+ out_v, feed_v = sess.run(
+ [out_t, feed_t], feed_dict={
+ feed_t: np_array
+ })
self.assertAllEqual(np_array, out_v)
self.assertAllEqual(np_array, feed_v)
@@ -1257,9 +1297,11 @@ class SessionTest(test_util.TensorFlowTestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
- self.assertAllClose(
- 42.0,
- tensor_runner(41.0, options=run_options, run_metadata=run_metadata))
+ self.assertAllClose(42.0,
+ tensor_runner(
+ 41.0,
+ options=run_options,
+ run_metadata=run_metadata))
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
def testFeedError(self):
@@ -1296,8 +1338,9 @@ class SessionTest(test_util.TensorFlowTestCase):
size = 1
for s in shape:
size *= s
- c_list = np.array([compat.as_bytes(str(i)) for i in xrange(size)],
- dtype=np.object).reshape(shape) if size > 0 else []
+ c_list = np.array(
+ [compat.as_bytes(str(i)) for i in xrange(size)],
+ dtype=np.object).reshape(shape) if size > 0 else []
c = constant_op.constant(c_list)
self.assertAllEqual(c.eval(), c_list)
@@ -1307,13 +1350,16 @@ class SessionTest(test_util.TensorFlowTestCase):
size = 1
for s in shape:
size *= s
- c_list = np.array([compat.as_bytes(str(i)) for i in xrange(size)],
- dtype=np.object).reshape(shape)
+ c_list = np.array(
+ [compat.as_bytes(str(i)) for i in xrange(size)],
+ dtype=np.object).reshape(shape)
feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape)
c = array_ops.identity(feed_t)
self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list)
- self.assertAllEqual(sess.run(feed_t, feed_dict={feed_t: c_list}),
- c_list)
+ self.assertAllEqual(
+ sess.run(feed_t, feed_dict={
+ feed_t: c_list
+ }), c_list)
c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list})
self.assertAllEqual(c_v, c_list)
self.assertAllEqual(feed_v, c_list)
@@ -1329,8 +1375,10 @@ class SessionTest(test_util.TensorFlowTestCase):
def testStringFeedWithUnicode(self):
with session.Session():
- c_list = [u'\n\x01\x00', u'\n\x00\x01',
- u'\u26a3 unicode', u'\U0001f60e deal with it']
+ c_list = [
+ u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode',
+ u'\U0001f60e deal with it'
+ ]
feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)])
c = array_ops.identity(feed_t)
@@ -1423,9 +1471,10 @@ class SessionTest(test_util.TensorFlowTestCase):
sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
self.assertTrue(not run_metadata.HasField('step_stats'))
- sess.run(constant_op.constant(1.0),
- options=run_options,
- run_metadata=run_metadata)
+ sess.run(
+ constant_op.constant(1.0),
+ options=run_options,
+ run_metadata=run_metadata)
self.assertTrue(run_metadata.HasField('step_stats'))
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
@@ -1439,23 +1488,26 @@ class SessionTest(test_util.TensorFlowTestCase):
with session.Session() as sess:
# all combinations are valid
sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
- sess.run(constant_op.constant(1.0), options=None,
- run_metadata=run_metadata)
+ sess.run(
+ constant_op.constant(1.0), options=None, run_metadata=run_metadata)
self.assertTrue(not run_metadata.HasField('step_stats'))
- sess.run(constant_op.constant(1.0), options=run_options,
- run_metadata=None)
+ sess.run(
+ constant_op.constant(1.0), options=run_options, run_metadata=None)
self.assertTrue(not run_metadata.HasField('step_stats'))
- sess.run(constant_op.constant(1.0), options=run_options,
- run_metadata=run_metadata)
+ sess.run(
+ constant_op.constant(1.0),
+ options=run_options,
+ run_metadata=run_metadata)
self.assertTrue(run_metadata.HasField('step_stats'))
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
def testFeedShapeCompatibility(self):
# TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
- if ops._USE_C_API: return
+ if ops._USE_C_API:
+ return
with session.Session() as sess:
some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
@@ -1499,8 +1551,11 @@ class SessionTest(test_util.TensorFlowTestCase):
d = math_ops.multiply(c, c)
for step in xrange(120):
run_metadata = config_pb2.RunMetadata()
- sess.run(d, feed_dict={a: 1.0},
- options=run_options, run_metadata=run_metadata)
+ sess.run(
+ d,
+ feed_dict={a: 1.0},
+ options=run_options,
+ run_metadata=run_metadata)
if step == 99:
self.assertTrue(run_metadata.HasField('cost_graph'))
else:
@@ -1569,8 +1624,7 @@ class SessionTest(test_util.TensorFlowTestCase):
def testTimeoutWithShortOperations(self):
num_epochs = 5
- q = data_flow_ops.FIFOQueue(
- capacity=50, dtypes=[dtypes.int32], shapes=[()])
+ q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()])
enqueue_op = q.enqueue_many(constant_op.constant([1, 2]))
# Use a 10-second timeout, which should be longer than any
@@ -1582,7 +1636,9 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(sess.run(q.size()), num_epochs * 2)
def testRegisterFetchAndFeedConversionFunctions(self):
+
class SquaredTensor(object):
+
def __init__(self, tensor):
self.sq = math_ops.square(tensor)
@@ -1591,24 +1647,27 @@ class SessionTest(test_util.TensorFlowTestCase):
feed_fn2 = lambda feed: [feed.sq]
session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
- feed_fn1, feed_fn2)
+ feed_fn1, feed_fn2)
with self.assertRaises(ValueError):
- session.register_session_run_conversion_functions(SquaredTensor,
- fetch_fn, feed_fn1, feed_fn2)
+ session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
+ feed_fn1, feed_fn2)
with self.test_session() as sess:
np1 = np.array([1.0, 1.5, 2.0, 2.5])
np2 = np.array([3.0, 3.5, 4.0, 4.5])
squared_tensor = SquaredTensor(np2)
squared_eval = sess.run(squared_tensor)
self.assertAllClose(np2 * np2, squared_eval)
- squared_eval = sess.run(squared_tensor, feed_dict={
- squared_tensor : np1 * np1})
+ squared_eval = sess.run(
+ squared_tensor, feed_dict={
+ squared_tensor: np1 * np1
+ })
self.assertAllClose(np1 * np1, squared_eval)
partial_run = sess.partial_run_setup([squared_tensor], [])
squared_eval = sess.partial_run(partial_run, squared_tensor)
self.assertAllClose(np2 * np2, squared_eval)
def testDefaultLogDevicePlacement(self):
+
class CaptureStderr(str):
"""Class to capture stderr from C++ shared library."""
@@ -1719,6 +1778,7 @@ class SessionTest(test_util.TensorFlowTestCase):
def runTestAddFunctionToSession(self, target=''):
"""Add a function to a session after the graph has already been run."""
+
@function.Defun(dtypes.float32)
def foo(x):
return x + 1
@@ -1753,6 +1813,7 @@ class SessionTest(test_util.TensorFlowTestCase):
TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'):
sess.run(a, feed_dict={a: 1})
+
class GraphMutationTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -1803,8 +1864,7 @@ class GraphMutationTest(test_util.TensorFlowTestCase):
with session.Session(graph=g) as sess:
self.assertAllEqual(1.0, sess.run(b))
- b.op._set_attr('DstT',
- attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
+ b.op._set_attr('DstT', attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
with self.assertRaisesRegexp(
errors.FailedPreconditionError,
'Cast.*was changed by setting attribute after it was run'):
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 5fb389cf92..43cbde69d9 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -59,7 +59,7 @@ tf_py_test(
tf_py_test(
name = "dataset_from_generator_op_test",
- size = "small",
+ size = "medium",
srcs = ["dataset_from_generator_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index 53c8be1d1d..bd80b9dbf5 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -50,8 +51,9 @@ class BatchDatasetTest(test.TestCase):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
- iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(count).batch(batch_size).make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(count).batch(batch_size).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -67,7 +69,7 @@ class BatchDatasetTest(test.TestCase):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
- self.assertAllEqual(component[(i*14 + j) % 7]**2,
+ self.assertAllEqual(component[(i * 14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -82,12 +84,12 @@ class BatchDatasetTest(test.TestCase):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
- self.assertAllEqual(component[(i*8 + j) % 7]**2,
+ self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2,
+ self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -188,33 +190,34 @@ class BatchDatasetTest(test.TestCase):
sess.run(get_next)
def testBatchShapeError(self):
+
def generator():
yield [1.0, 2.0, 3.0]
yield [4.0, 5.0, 6.0]
yield [7.0, 8.0, 9.0, 10.0]
- iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32,
- output_shapes=[None])
- .batch(3)
- .make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_generator(
+ generator, dtypes.float32, output_shapes=[None]).batch(3)
+ .make_initializable_iterator())
next_element = iterator.get_next()
with self.test_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- r"Cannot batch tensors with different shapes in component 0. "
- r"First element had shape \[3\] and element 2 had shape \[4\]."):
+ r'Cannot batch tensors with different shapes in component 0. '
+ r'First element had shape \[3\] and element 2 had shape \[4\].'):
sess.run(next_element)
def testPaddedBatchDataset(self):
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
- iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens)
- .map(lambda x: array_ops.fill([x], x)).padded_batch(
- 4,
- padded_shapes=padded_shape).make_initializable_iterator())
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(seq_lens)
+ .map(lambda x: array_ops.fill([x], x)).padded_batch(
+ 4, padded_shapes=padded_shape).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -222,35 +225,40 @@ class BatchDatasetTest(test.TestCase):
with self.test_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [-1],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
padded_len = np.max(result)
self.assertEqual((4, padded_len), result.shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test with random sequence lengths, and constant padding.
- sess.run(init_op, feed_dict={padded_shape: [25],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [25],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
self.assertEqual((4, 25), result.shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test correct handling of empty tensors.
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: [0, 0, 0, 0]})
+ sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
@@ -258,8 +266,7 @@ class BatchDatasetTest(test.TestCase):
# Test error handling with constant sequence lengths, and
# too-short padding.
- sess.run(init_op, feed_dict={padded_shape: [5],
- seq_lens: [6, 5, 5, 5]})
+ sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
with self.assertRaises(errors.DataLossError):
result = sess.run(get_next)
@@ -270,11 +277,13 @@ class BatchDatasetTest(test.TestCase):
def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled))
- iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
- .padded_batch(
- 4,
- padded_shapes=(padded_shape, padded_shape),
- padding_values=(-1, "<end>")).make_initializable_iterator())
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
+ .padded_batch(
+ 4,
+ padded_shapes=(padded_shape, padded_shape),
+ padding_values=(-1, '<end>')).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -282,25 +291,46 @@ class BatchDatasetTest(test.TestCase):
with self.test_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
- sess.run(init_op, feed_dict={padded_shape: [-1],
- seq_lens: random_seq_lens})
+ sess.run(
+ init_op, feed_dict={
+ padded_shape: [-1],
+ seq_lens: random_seq_lens
+ })
for i in range(8):
result = sess.run(get_next)
padded_len = np.max(result[0])
self.assertEqual((4, padded_len), result[0].shape)
self.assertEqual((4, padded_len), result[1].shape)
for j in range(4):
- seq_len = random_seq_lens[(i*4)+j]
+ seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[0][j, seq_len:],
[-1] * (padded_len - seq_len))
self.assertAllEqual(result[1][j, :seq_len],
[compat.as_bytes(str(seq_len))] * seq_len)
self.assertAllEqual(result[1][j, seq_len:],
- [b"<end>"] * (padded_len - seq_len))
+ [b'<end>'] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testPaddedBatchDatasetUnicode(self):
+ # See GitHub issue 16149
+ def generator():
+ data = [[u'Простой', u'тест', u'юникода'],
+ [u'никогда', u'не', u'бывает', u'простым']]
+
+ for seq in data:
+ yield seq, [0, 1, 2, 3]
+
+ dataset = dataset_ops.Dataset.from_generator(
+ generator, (dtypes.string, dtypes.int32),
+ (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
+ padded_dataset = dataset.padded_batch(
+ 2, padded_shapes=([None], [None]), padding_values=('', 0))
+ with self.test_session() as sess:
+ next_element = padded_dataset.make_one_shot_iterator().get_next()
+ sess.run(next_element)
+
def testPaddedBatchDatasetShapeSpecifications(self):
int_placeholder = array_ops.placeholder(dtypes.int32)
float_placeholder = array_ops.placeholder(dtypes.float32)
@@ -324,15 +354,16 @@ class BatchDatasetTest(test.TestCase):
constant_op.constant([-1, -1], dtype=dtypes.int64),
constant_op.constant([37], dtype=dtypes.int64)))
- for dataset in [dynamic_padding_from_tensor_shapes,
- dynamic_padding_from_lists,
- dynamic_padding_from_lists_with_minus_one,
- dynamic_padding_from_tensors]:
+ for dataset in [
+ dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
+ dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
+ ]:
self.assertEqual([None, None], dataset.output_shapes[0].as_list())
self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
def testPaddedBatchSparseError(self):
+
def _map_fn(i):
return sparse_tensor.SparseTensorValue(
indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
@@ -341,5 +372,5 @@ class BatchDatasetTest(test.TestCase):
_ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
-if __name__ == "__main__":
+if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 0594c6d6a7..c1ba67e474 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -201,7 +201,7 @@ class Dataset(object):
tensors: A nested structure of tensors.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return TensorDataset(tensors)
@@ -214,7 +214,7 @@ class Dataset(object):
0th dimension.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return TensorSliceDataset(tensors)
@@ -227,7 +227,7 @@ class Dataset(object):
sparse_tensor: A `tf.SparseTensor`.
Returns:
- A `Dataset` of rank-(N-1) sparse tensors.
+ Dataset: A `Dataset` of rank-(N-1) sparse tensors.
"""
return SparseTensorSliceDataset(sparse_tensor)
@@ -313,7 +313,7 @@ class Dataset(object):
`generator`.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
if not callable(generator):
raise TypeError("`generator` must be callable.")
@@ -456,7 +456,7 @@ class Dataset(object):
len(args) == 3 -> start = args[0], stop = args[1, stop = args[2]
Returns:
- A `RangeDataset`.
+ Dataset: A `RangeDataset`.
Raises:
ValueError: if len(args) == 0.
@@ -500,7 +500,7 @@ class Dataset(object):
datasets: A nested structure of datasets.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return ZipDataset(datasets)
@@ -526,7 +526,7 @@ class Dataset(object):
dataset: `Dataset` to be concatenated.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return ConcatenateDataset(self, dataset)
@@ -538,7 +538,7 @@ class Dataset(object):
maximum number elements that will be buffered when prefetching.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return PrefetchDataset(self, buffer_size)
@@ -561,7 +561,7 @@ class Dataset(object):
the filename pattern that will be matched.
Returns:
- A `Dataset` of strings corresponding to file names.
+ Dataset: A `Dataset` of strings corresponding to file names.
"""
return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern))
@@ -578,7 +578,7 @@ class Dataset(object):
indefinitely.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return RepeatDataset(self, count)
@@ -602,7 +602,7 @@ class Dataset(object):
iterated over. (Defaults to `True`.)
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
@@ -615,7 +615,7 @@ class Dataset(object):
If a filename is not provided, the dataset will be cached in memory.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return CacheDataset(self, filename)
@@ -629,7 +629,7 @@ class Dataset(object):
dataset, the new dataset will contain all elements of this dataset.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return TakeDataset(self, count)
@@ -644,7 +644,7 @@ class Dataset(object):
is -1, skips the entire dataset.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return SkipDataset(self, count)
@@ -691,7 +691,7 @@ class Dataset(object):
index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
Raises:
ValueError: if `num_shards` or `index` are illegal values. Note: error
@@ -735,7 +735,7 @@ class Dataset(object):
consecutive elements of this dataset to combine in a single batch.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return BatchDataset(self, batch_size)
@@ -764,7 +764,7 @@ class Dataset(object):
the empty string for string types.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values)
@@ -780,7 +780,7 @@ class Dataset(object):
specified, elements will be processed sequentially.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
if num_parallel_calls is None:
return MapDataset(self, map_func)
@@ -796,7 +796,7 @@ class Dataset(object):
`Dataset`.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return FlatMapDataset(self, map_func)
@@ -865,7 +865,7 @@ class Dataset(object):
input element before cycling to another input element.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return InterleaveDataset(self, map_func, cycle_length, block_length)
@@ -878,7 +878,7 @@ class Dataset(object):
scalar `tf.bool` tensor.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return FilterDataset(self, predicate)
@@ -902,7 +902,7 @@ class Dataset(object):
returns a `Dataset`.
Returns:
- The `Dataset` returned by applying `transformation_func` to this dataset.
+ Dataset: The `Dataset` returned by applying `transformation_func` to this dataset.
"""
dataset = transformation_func(self)
if not isinstance(dataset, Dataset):
diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py
index 6fd89e018a..b6c7280a41 100644
--- a/tensorflow/python/debug/lib/debug_gradients_test.py
+++ b/tensorflow/python/debug/lib/debug_gradients_test.py
@@ -39,7 +39,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
def setUp(self):
self.sess = session.Session()
- with self.sess:
+ with self.sess.as_default():
self.u = variables.Variable(2.0, name="u")
self.v = variables.Variable(3.0, name="v")
self.w = math_ops.multiply(self.u.value(), self.v.value(), name="w")
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index f470e18120..9e3382d4f3 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -1,8 +1,7 @@
licenses(["notice"]) # Apache 2.0
-load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_binary")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load(
"//tensorflow/tools/test:performance.bzl",
"tf_py_logged_benchmark",
@@ -423,6 +422,22 @@ cuda_py_test(
],
)
+py_test(
+ name = "pywrap_tfe_test",
+ srcs = ["pywrap_tfe_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":backprop",
+ ":context",
+ ":test",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:random_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets.
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 9849f0f322..75526ba9c1 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -42,11 +42,25 @@ from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-
CPU = "/device:CPU:0"
GPU = "/device:GPU:0"
+def record_gradient_callback(inputs, attrs, results):
+ return backprop._record_gradient("MatMul", inputs, attrs, results, None)
+
+
+def c_tfe_py_fastpath_execute(a, b, transpose_a=False, transpose_b=False):
+ ctx = context.context()
+ assert not ctx.in_graph_mode(
+ ), "The prototype doesn't contain C code for graph construction"
+ ctx_handle = ctx._handle # pylint: disable=protected-access
+
+ return pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx_handle, None, "MatMul", record_gradient_callback, a, b,
+ "transpose_a", transpose_a, "transpose_b", transpose_b)[0]
+
+
class MicroBenchmarks(test.Benchmark):
def __init__(self):
@@ -222,6 +236,14 @@ class MicroBenchmarks(test.Benchmark):
gen_math_ops._mat_mul(m, m, transpose_b=transpose_b)
self._run(func, num_iters)
+ def _benchmark_tfe_py_fastpath_execute_matmul(self, m, transpose_b,
+ num_iters):
+
+ def func():
+ c_tfe_py_fastpath_execute(m, m, transpose_b=transpose_b)
+
+ self._run(func, num_iters)
+
def _benchmark_tfe_py_execute_matmul(self, m, transpose_b, num_iters):
inputs = [m, m]
# pylint: disable=protected-access
@@ -257,6 +279,12 @@ class MicroBenchmarks(test.Benchmark):
self._benchmark_gen_math_ops_matmul(
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
+ def benchmark_tfe_py_fastpath_execute_matmul_2_by_2_CPU(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_tfe_py_fastpath_execute_matmul(
+ m, transpose_b=False, num_iters=self._num_iters_2_by_2)
+
def benchmark_tfe_py_execute_matmul_2_by_2_CPU(self):
with context.device(CPU):
m = self._m_2_by_2.cpu()
@@ -320,6 +348,12 @@ class MicroBenchmarks(test.Benchmark):
self._benchmark_gen_math_ops_matmul(
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
+ def benchmark_tfe_py_fastpath_execute_matmul_100_by_784_CPU(self):
+ with context.device(CPU):
+ m = self._m_100_by_784.cpu()
+ self._benchmark_tfe_py_fastpath_execute_matmul(
+ m, transpose_b=True, num_iters=self._num_iters_100_by_784)
+
def benchmark_tfe_py_execute_matmul_100_by_784_CPU(self):
with context.device(CPU):
m = self._m_100_by_784.cpu()
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index cbf588336d..b6c7d82323 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -49,6 +49,8 @@ _MAXINT32 = 2**31 - 1
DEVICE_PLACEMENT_EXPLICIT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_EXPLICIT
DEVICE_PLACEMENT_WARN = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_WARN
DEVICE_PLACEMENT_SILENT = pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT
+DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
+ pywrap_tensorflow.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
# TODO(agarwal): better name ?
@@ -122,6 +124,8 @@ class Context(object):
right device but raises a warning.
tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
hide performance problems.
+ tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
+ raising errors on the other ones.
"""
self._eager_context = _EagerContext()
self._context_handle = None
@@ -411,6 +415,20 @@ class Context(object):
self._initialize_handle_and_devices()
pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)
+ @tf_contextlib.contextmanager
+ def device_policy(self, policy):
+ if not self._context_handle:
+ self._initialize_handle_and_devices()
+ old = pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
+ self._context_handle)
+ pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
+ self._handle, policy)
+ try:
+ yield
+ finally:
+ pywrap_tensorflow.TFE_ContextSetThreadLocalDevicePlacementPolicy(
+ self._handle, old)
+
def disable_run_metadata(self):
"""Disables tracing of op execution via RunMetadata."""
if not self._context_handle:
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index a70fa72804..ee3c10633e 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import nn_ops
def execute(op_name, num_outputs, inputs, attrs=None):
@@ -112,6 +113,14 @@ class TFETest(test_util.TensorFlowTestCase):
# is enabled; the stack entry should reflect this fact.
self.assertFalse(stack_entry.is_building_function)
+ def testInt32GPU(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+ with ops.device('gpu:0'):
+ xent = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ logits=[[0.0, 0.0]], labels=[0])
+ self.assertAllClose(xent, [0.69314718])
+
def _runInThread(self, target, args):
t = threading.Thread(target=target, args=args)
try:
@@ -173,6 +182,15 @@ class TFETest(test_util.TensorFlowTestCase):
with self.assertRaises(RuntimeError):
x.gpu(context.context().num_gpus() + 1)
+ def testCopyScope(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+ constant = constant_op.constant(1.0)
+ with ops.device('gpu:0'):
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ c = constant + 1.0
+ self.assertAllEqual(c, 2.0)
+
def testNumpyForceCPU(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index f755434ad7..81b1f6f12a 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -120,6 +120,8 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
tensor_map = _scoped_captures.tensors
if tensor_map is None:
# Capturing is not enabled.
+ if value.dtype == dtypes_module.resource:
+ return value
return constant_op.constant(value.numpy())
if type(value) == ops.Tensor and value.graph is default_graph:
# The tensor has already been converted and captured. The type check
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 9b08a35ff1..0babc29f17 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -400,10 +400,11 @@ class FunctionTest(test.TestCase):
# The Reshape op requires the shape tensor to be placed in host memory.
reshape = function.defun(array_ops.reshape)
- value = constant_op.constant([1., 2.]).gpu()
+ value = constant_op.constant([1., 2.])
shape = constant_op.constant([2, 1]).gpu()
with self.assertRaises(errors.InvalidArgumentError):
- reshape(value, shape)
+ with ops.device('gpu:0'):
+ reshape(value, shape)
def testDifferentiableFunctionNoneOutputs(self):
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index f8c5037dcf..f2e70341d9 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -24,7 +24,6 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
@@ -246,15 +245,6 @@ class OpsTest(test_util.TensorFlowTestCase):
reshaped = array_ops.reshape(value, shape)
self.assertAllEqual([[1], [2]], reshaped.cpu())
- # And if the shape is in device memory, it should complain
- # TODO(ashankar): Revisit this - perhaps instead of complaining,
- # it should implicitly copy the tensor to host memory?
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- 'cannot compute Reshape as input #1 was expected to be on.*'
- 'using.*DEVICE_PLACEMENT_SILENT'):
- reshaped = array_ops.reshape(value, shape.gpu())
-
def testInt64(self):
# Fill requires the first input to be an int32 tensor.
self.assertAllEqual(
diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h
index f9dfdf0408..d27b00139d 100644
--- a/tensorflow/python/eager/python_eager_op_gen.h
+++ b/tensorflow/python/eager/python_eager_op_gen.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
-#define THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
+#ifndef TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
+#define TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
#include <string>
#include <vector>
@@ -40,4 +40,4 @@ string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
+#endif // TENSORFLOW_PYTHON_EAGER_PYTHON_EAGER_OP_GEN_H_
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index cecef42603..4aea134fa9 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -131,6 +131,28 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
PyObject* target, PyObject* sources,
PyObject* output_gradients, TF_Status* status);
+// Execute a tensorflow operation assuming that all provided inputs are
+// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,
+// it will simply fail with a NotImplementedError.
+//
+// The first PyObject* is unused.
+// The "args" PyObject* is meant to be a tuple with the following structure:
+// Item 1: The TFE Context
+// Item 2: device_name: Name of the device on which to execute the operation,
+// or NULL for automatic selection.
+// Item 3: op_name: Name of the TensorFlow op to execute.
+// Item 4: record_gradient_callback: Callback that records the gradient of the
+// result.
+// The callback takes (inputs, attrs, result) - all sequences and
+// records the gradient.
+// Item 5 onwards: inputs - This is a list of inputs followed by a list of
+// attrs. It is not necessary for type attrs to be present.
+//
+// This is named _C since there doesn't seem to be any way to make it visible
+// in the SWIG interface without renaming due to the use of the %native
+// directive.
+PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args);
+
// Returns the set of variables watched by the given tape.
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 38c3cb2174..647f03351d 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -21,12 +21,16 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
using tensorflow::string;
+using tensorflow::strings::Printf;
namespace {
@@ -289,14 +293,12 @@ bool SetOpAttrScalar(TFE_Context* ctx, TFE_Op* op, const char* key,
return true;
}
-void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs,
+// start_index is the index at which the Tuple/List attrs will start getting
+// processed.
+void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
TF_Status* out_status) {
if (attrs == Py_None) return;
- if (!PyTuple_Check(attrs)) {
- TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Expecting an attrs tuple.");
- return;
- }
- Py_ssize_t len = PyTuple_GET_SIZE(attrs);
+ Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
if ((len & 1) != 0) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
"Expecting attrs tuple to have even length.");
@@ -304,8 +306,8 @@ void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs,
}
// Parse attrs
for (Py_ssize_t i = 0; i < len; i += 2) {
- PyObject* py_key = PyTuple_GET_ITEM(attrs, i);
- PyObject* py_value = PyTuple_GET_ITEM(attrs, i + 1);
+ PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
+ PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
#if PY_MAJOR_VERSION >= 3
const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
: PyUnicode_AsUTF8(py_key);
@@ -329,7 +331,6 @@ PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
static tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
static tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0;
-
} // namespace
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
@@ -346,7 +347,7 @@ void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
}
}
if (TF_GetCode(out_status) == TF_OK) {
- SetOpAttrs(ctx, op, attrs, out_status);
+ SetOpAttrs(ctx, op, attrs, 0, out_status);
}
Py_BEGIN_ALLOW_THREADS;
if (TF_GetCode(out_status) == TF_OK) {
@@ -542,10 +543,10 @@ static PyTypeObject TFE_Py_Tape_Type = {
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when decide to not hold the GIL while manipulating the tape
// stack.
-static std::unordered_set<TFE_Py_Tape*>* tape_set = nullptr;
-std::unordered_set<TFE_Py_Tape*>* GetTapeSet() {
+static tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set = nullptr;
+tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
if (tape_set == nullptr) {
- tape_set = new std::unordered_set<TFE_Py_Tape*>;
+ tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>;
}
return tape_set;
}
@@ -636,8 +637,8 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
if (*ThreadTapeIsStopped()) {
Py_RETURN_FALSE;
}
- auto* tape_set = GetTapeSet();
- if (tape_set->empty()) {
+ auto* tape_set_ptr = GetTapeSet();
+ if (tape_set_ptr->empty()) {
Py_RETURN_FALSE;
}
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -654,7 +655,8 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) {
tensor_ids.push_back(FastTensorId(item));
}
Py_DECREF(seq);
- for (TFE_Py_Tape* tape : *tape_set) {
+ auto tape_set = *tape_set_ptr;
+ for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids)) {
Py_RETURN_TRUE;
}
@@ -760,11 +762,13 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensors,
PyObject* backward_function) {
- auto* set = GetTapeSet();
- if (set->empty() || *ThreadTapeIsStopped()) {
+ if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) {
return;
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
+ if (PyErr_Occurred()) {
+ return;
+ }
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
@@ -796,7 +800,8 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
return;
}
- for (TFE_Py_Tape* tape : *set) {
+ auto set = *GetTapeSet();
+ for (TFE_Py_Tape* tape : set) {
Py_INCREF(backward_function);
tape->tape->RecordOperation(
op_type_str, output_info, input_ids, backward_function,
@@ -805,7 +810,10 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
}
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
- for (TFE_Py_Tape* tape : *GetTapeSet()) {
+ // Note: making a copy because deleting the trace can trigger a change to the
+ // set of tapes by allowing python's garbage collector to run.
+ auto tape_set = *GetTapeSet();
+ for (TFE_Py_Tape* tape : tape_set) {
tape->tape->DeleteTrace(tensor_id);
}
}
@@ -974,7 +982,6 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
return list;
}
-
PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
PyObject* target, PyObject* sources,
PyObject* output_gradients, TF_Status* status) {
@@ -1029,3 +1036,195 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
Py_INCREF(Py_None);
return Py_None;
}
+
+namespace {
+static const int kFastPathExecuteInputStartIndex = 4;
+
+bool CheckEagerTensors(PyObject* seq, int start_index, int num_to_check) {
+ for (int i = start_index; i < start_index + num_to_check; i++) {
+ PyObject* item = PyTuple_GET_ITEM(seq, i);
+ if (!EagerTensor_CheckExact(item)) return false;
+ }
+
+ return true;
+}
+
+const tensorflow::OpDef* GetOpDef(PyObject* py_op_name) {
+ const char* op_name = TFE_GetPythonString(py_op_name);
+ if (op_name == nullptr) {
+ PyErr_SetString(PyExc_TypeError,
+ Printf("expected a string for op_name, got %s instead",
+ py_op_name->ob_type->tp_name)
+ .c_str());
+ return nullptr;
+ }
+
+ const tensorflow::OpRegistrationData* op_reg_data = nullptr;
+ const tensorflow::Status lookup_status =
+ tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+ if (MaybeRaiseExceptionFromStatus(lookup_status, nullptr)) {
+ return nullptr;
+ }
+ return &op_reg_data->op_def;
+}
+
+const char* GetDeviceName(PyObject* py_device_name) {
+ if (py_device_name != Py_None) {
+ return TFE_GetPythonString(py_device_name);
+ }
+ return nullptr;
+}
+
+bool MaybeRunRecordGradientCallback(const tensorflow::OpDef* op_def,
+ PyObject* args, PyObject* result,
+ PyObject* record_gradient_callback) {
+ if (*ThreadTapeIsStopped() || GetTapeSet()->empty() ||
+ record_gradient_callback == Py_None) {
+ return true;
+ }
+ if (!PyCallable_Check(record_gradient_callback)) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ Printf(
+ "expected a function for record_gradient_callback, got %s instead",
+ record_gradient_callback->ob_type->tp_name)
+ .c_str());
+ return false;
+ }
+
+ PyObject* inputs = PyTuple_New(op_def->input_arg_size());
+ for (int i = 0; i < op_def->input_arg_size(); i++) {
+ auto* input = PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
+ Py_INCREF(input);
+ PyTuple_SET_ITEM(inputs, i, input);
+ }
+
+ int args_size = PyTuple_GET_SIZE(args);
+ int num_attrs =
+ args_size - op_def->input_arg_size() - kFastPathExecuteInputStartIndex;
+ PyObject* attrs = PyTuple_New(num_attrs);
+ for (int i = 0; i < num_attrs; i++) {
+ auto* attr = PyTuple_GET_ITEM(
+ args, kFastPathExecuteInputStartIndex + op_def->input_arg_size() + i);
+ Py_INCREF(attr);
+ PyTuple_SET_ITEM(attrs, i, attr);
+ }
+
+ PyObject* callback_args = Py_BuildValue("OOO", inputs, attrs, result);
+ PyObject_CallObject(record_gradient_callback, callback_args);
+
+ Py_DECREF(inputs);
+ Py_DECREF(callback_args);
+ Py_DECREF(attrs);
+ return true;
+}
+} // namespace
+
+PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
+ TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
+ PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
+ const tensorflow::OpDef* op_def = GetOpDef(PyTuple_GET_ITEM(args, 2));
+ if (op_def == nullptr) return nullptr;
+ const char* device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
+ PyObject* record_gradient_callback = PyTuple_GET_ITEM(args, 3);
+
+ Py_ssize_t args_size = PyTuple_GET_SIZE(args);
+ if (args_size < kFastPathExecuteInputStartIndex) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ Printf("There must be at least %d items in the input tuple.",
+ kFastPathExecuteInputStartIndex)
+ .c_str());
+ return nullptr;
+ }
+
+ if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ Printf("Tuple size smaller than intended. Expected to be at least %d, "
+ "was %ld",
+ kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
+ args_size)
+ .c_str());
+ return nullptr;
+ }
+
+ if (!CheckEagerTensors(args, kFastPathExecuteInputStartIndex,
+ op_def->input_arg_size())) {
+ // TODO(nareshmodi): Maybe some other way of signalling that this should
+ // fall back?
+ PyErr_SetString(PyExc_NotImplementedError,
+ "This function does not handle the case of the path where "
+ "all inputs are not already EagerTensors.");
+ return nullptr;
+ }
+
+ TF_Status* status = TF_NewStatus();
+ TFE_Op* op = TFE_NewOp(ctx, op_def->name().c_str(), status);
+ auto cleaner = tensorflow::gtl::MakeCleanup([status, op] {
+ TF_DeleteStatus(status);
+ TFE_DeleteOp(op);
+ });
+ if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+ return nullptr;
+ }
+
+ TFE_OpSetDevice(op, device_name, status);
+ if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+ return nullptr;
+ }
+
+ // Add non-type attrs.
+ SetOpAttrs(ctx, op, args,
+ kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
+ status);
+ if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+ return nullptr;
+ }
+
+ // Add type attrs and inputs.
+ for (int i = 0; i < op_def->input_arg_size(); i++) {
+ const auto& input_arg = op_def->input_arg(i);
+
+ PyObject* input =
+ PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
+ TFE_TensorHandle* input_handle = EagerTensor_Handle(input);
+
+ // The following code might set duplicate type attrs. This will result in
+ // the CacheKey for the generated AttrBuilder possibly differing from those
+ // where the type attrs are correctly set. Inconsistent CacheKeys for ops
+ // means that there might be unnecessarily duplicated kernels.
+ // TODO(nareshmodi): Fix this.
+ if (!input_arg.type_attr().empty()) {
+ TFE_OpSetAttrType(op, input_arg.type_attr().data(),
+ TFE_TensorHandleDataType(input_handle));
+ }
+
+ TFE_OpAddInput(op, input_handle, status);
+ if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+ return nullptr;
+ }
+ }
+
+ int num_retvals = op_def->output_arg_size();
+ tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
+
+ Py_BEGIN_ALLOW_THREADS;
+ TFE_Execute(op, retvals.data(), &num_retvals, status);
+ Py_END_ALLOW_THREADS;
+ if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
+ return nullptr;
+ }
+
+ PyObject* result = PyTuple_New(num_retvals);
+ for (int i = 0; i < num_retvals; ++i) {
+ PyTuple_SET_ITEM(result, i, EagerTensorFromHandle(retvals[i]));
+ }
+
+ if (!MaybeRunRecordGradientCallback(op_def, args, result,
+ record_gradient_callback)) {
+ return nullptr;
+ }
+
+ return result;
+}
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
new file mode 100644
index 0000000000..d4f4ed592f
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for low-level eager execution primitives."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+
+
+def record_gradient_callback(inputs, attrs, results):
+ return backprop._record_gradient("MatMul", inputs, attrs, results, None)
+
+
+def c_tfe_py_fastpath_execute(a, b, transpose_a=False, transpose_b=False):
+ ctx = context.context()
+ assert not ctx.in_graph_mode(
+ ), "The prototype doesn't contain C code for graph construction"
+ ctx_handle = ctx._handle # pylint: disable=protected-access
+
+ return pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx_handle, ctx.device_name, "MatMul", record_gradient_callback, a, b,
+ "transpose_a", transpose_a, "transpose_b", transpose_b)[0]
+
+
+class Tests(test.TestCase):
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_MatMulCorrectResponse(self):
+ a_2_by_2 = random_ops.random_uniform((2, 2))
+ b_2_by_2 = random_ops.random_uniform((2, 2))
+
+ a_100_by_784 = random_ops.random_uniform((100, 784))
+ b_100_by_784 = random_ops.random_uniform((100, 784))
+
+ self.assertAllClose(
+ math_ops.matmul(a_2_by_2, b_2_by_2),
+ c_tfe_py_fastpath_execute(a_2_by_2, b_2_by_2))
+ self.assertAllClose(
+ math_ops.matmul(a_100_by_784, b_100_by_784, transpose_b=True),
+ c_tfe_py_fastpath_execute(a_100_by_784, b_100_by_784, transpose_b=True))
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_TapeWrite(self):
+ with backprop.GradientTape(persistent=True) as tape:
+ a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
+ tape.watch(a_2_by_2)
+ z = c_tfe_py_fastpath_execute(a_2_by_2, a_2_by_2)
+ dz_dy = tape.gradient(z, [a_2_by_2])[0]
+ self.assertAllEqual(dz_dy.numpy(),
+ constant_op.constant(4.0, shape=[2, 2]).numpy())
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_MatMulSlowPath(self):
+ a_2_by_2 = random_ops.random_uniform((2, 2)).cpu().numpy()
+
+ with self.assertRaises(NotImplementedError):
+ c_tfe_py_fastpath_execute(a_2_by_2, a_2_by_2)
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_InvalidInputs(self):
+ a_2_by_2 = random_ops.random_uniform((2, 2))
+ ctx = context.context()
+ assert not ctx.in_graph_mode(
+ ), "The prototype doesn't contain C code for graph construction"
+ ctx_handle = ctx._handle # pylint: disable=protected-access
+
+ with self.assertRaisesRegexp(ValueError,
+ "at least 4 items in the input tuple"):
+ pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
+ "Identity")
+
+ with self.assertRaisesRegexp(ValueError,
+ "Expected to be at least 5, was 4"):
+ pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx_handle, ctx_handle, "Identity", record_gradient_callback)
+
+ with self.assertRaisesRegexp(TypeError, "expected a string for op_name"):
+ pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx_handle, ctx.device_name, ctx_handle, record_gradient_callback,
+ a_2_by_2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 2568d3dc05..0bd5a5dbaf 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -112,6 +112,19 @@ class TFETensorTest(test_util.TensorFlowTestCase):
numpy_tensor = np.asarray(tensor, dtype=np.int32)
self.assertAllEqual(numpy_tensor, [1, 2, 3])
+ def testNdimsAgreesWithNumpy(self):
+ numpy_tensor = np.asarray(1.0)
+ tensor = constant_op.constant(numpy_tensor)
+ self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
+
+ numpy_tensor = np.asarray([1.0, 2.0, 3.0])
+ tensor = constant_op.constant(numpy_tensor)
+ self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
+
+ numpy_tensor = np.asarray([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
+ tensor = constant_op.constant(numpy_tensor)
+ self.assertAllEqual(numpy_tensor.ndim, tensor.ndim)
+
def testCopy(self):
t = constant_op.constant(1.0)
tt = copy.copy(t)
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 6343615737..c519fd557a 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -267,6 +267,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
@@ -356,6 +357,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
@@ -602,6 +604,7 @@ py_library(
":metric_keys",
":model_fn",
":prediction_keys",
+ ":util",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
@@ -624,8 +627,9 @@ py_library(
py_test(
name = "head_test",
- size = "small",
+ size = "medium",
srcs = ["canned/head_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
@@ -679,6 +683,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 0392ff9a71..0f274a23c0 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -22,7 +22,6 @@ import six
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import warm_starting_util
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.feature_column import feature_column as feature_column_lib
@@ -31,6 +30,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
@@ -280,6 +280,7 @@ class DNNClassifier(estimator.Estimator):
input_layer_partitioner=None,
config=None,
warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM,
):
"""Initializes a `DNNClassifier` instance.
@@ -323,19 +324,23 @@ class DNNClassifier(estimator.Estimator):
string filepath is provided instead of a `WarmStartSettings`, then all
weights are warm-started, and it is assumed that vocabularies and Tensor
names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM`.
"""
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
def _model_fn(features, labels, mode, config):
- """Call the defined shared _dnn_model_fn and possibly warm-start."""
- estimator_spec = _dnn_model_fn(
+ """Call the defined shared _dnn_model_fn."""
+ return _dnn_model_fn(
features=features,
labels=labels,
mode=mode,
@@ -347,17 +352,10 @@ class DNNClassifier(estimator.Estimator):
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config)
- # pylint: disable=protected-access
- warm_start_settings = warm_starting_util._get_default_warm_start_settings(
- warm_start_from)
- if warm_start_settings:
- warm_starting_util._warm_start(warm_start_settings)
- # pylint: enable=protected-access
-
- return estimator_spec
super(DNNClassifier, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn, model_dir=model_dir, config=config,
+ warm_start_from=warm_start_from)
class DNNRegressor(estimator.Estimator):
@@ -441,6 +439,7 @@ class DNNRegressor(estimator.Estimator):
input_layer_partitioner=None,
config=None,
warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM,
):
"""Initializes a `DNNRegressor` instance.
@@ -478,17 +477,20 @@ class DNNRegressor(estimator.Estimator):
string filepath is provided instead of a `WarmStartSettings`, then all
weights are warm-started, and it is assumed that vocabularies and Tensor
names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM`.
"""
def _model_fn(features, labels, mode, config):
- """Call the defined shared _dnn_model_fn and possibly warm-start."""
- estimator_spec = _dnn_model_fn(
+ """Call the defined shared _dnn_model_fn."""
+ return _dnn_model_fn(
features=features,
labels=labels,
mode=mode,
head=head_lib. # pylint: disable=protected-access
_regression_head_with_mean_squared_error_loss(
- label_dimension=label_dimension, weight_column=weight_column),
+ label_dimension=label_dimension, weight_column=weight_column,
+ loss_reduction=loss_reduction),
hidden_units=hidden_units,
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
@@ -496,14 +498,7 @@ class DNNRegressor(estimator.Estimator):
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config)
- # pylint: disable=protected-access
- warm_start_settings = warm_starting_util._get_default_warm_start_settings(
- warm_start_from)
- if warm_start_settings:
- warm_starting_util._warm_start(warm_start_settings)
- # pylint: enable=protected-access
-
- return estimator_spec
super(DNNRegressor, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn, model_dir=model_dir, config=config,
+ warm_start_from=warm_start_from)
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 1d06a54a32..1a0f4c5c39 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -23,7 +23,6 @@ import math
import six
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import warm_starting_util
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import linear
@@ -34,6 +33,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
@@ -309,7 +309,8 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
label_vocabulary=None,
input_layer_partitioner=None,
config=None,
- warm_start_from=None):
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM):
"""Initializes a DNNLinearCombinedClassifier instance.
Args:
@@ -356,6 +357,8 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
string filepath is provided instead of a `WarmStartSettings`, then all
weights are warm-started, and it is assumed that vocabularies and Tensor
names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM`.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -371,16 +374,18 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes,
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
def _model_fn(features, labels, mode, config):
- """Call the _dnn_linear_combined_model_fn and possibly warm-start."""
- estimator_spec = _dnn_linear_combined_model_fn(
+ """Call the _dnn_linear_combined_model_fn."""
+ return _dnn_linear_combined_model_fn(
features=features,
labels=labels,
mode=mode,
@@ -394,17 +399,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
dnn_dropout=dnn_dropout,
input_layer_partitioner=input_layer_partitioner,
config=config)
- # pylint: disable=protected-access
- warm_start_settings = warm_starting_util._get_default_warm_start_settings(
- warm_start_from)
- if warm_start_settings:
- warm_starting_util._warm_start(warm_start_settings)
- # pylint: enable=protected-access
-
- return estimator_spec
super(DNNLinearCombinedClassifier, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn, model_dir=model_dir, config=config,
+ warm_start_from=warm_start_from)
class DNNLinearCombinedRegressor(estimator.Estimator):
@@ -490,7 +488,8 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
weight_column=None,
input_layer_partitioner=None,
config=None,
- warm_start_from=None):
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM):
"""Initializes a DNNLinearCombinedRegressor instance.
Args:
@@ -531,6 +530,8 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
string filepath is provided instead of a `WarmStartSettings`, then all
weights are warm-started, and it is assumed that vocabularies and Tensor
names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM`.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -545,14 +546,15 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
'must be defined.')
def _model_fn(features, labels, mode, config):
- """Call the _dnn_linear_combined_model_fn and possibly warm-start."""
- estimator_spec = _dnn_linear_combined_model_fn(
+ """Call the _dnn_linear_combined_model_fn."""
+ return _dnn_linear_combined_model_fn(
features=features,
labels=labels,
mode=mode,
head=head_lib. # pylint: disable=protected-access
_regression_head_with_mean_squared_error_loss(
- label_dimension=label_dimension, weight_column=weight_column),
+ label_dimension=label_dimension, weight_column=weight_column,
+ loss_reduction=loss_reduction),
linear_feature_columns=linear_feature_columns,
linear_optimizer=linear_optimizer,
dnn_feature_columns=dnn_feature_columns,
@@ -562,14 +564,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
dnn_dropout=dnn_dropout,
input_layer_partitioner=input_layer_partitioner,
config=config)
- # pylint: disable=protected-access
- warm_start_settings = warm_starting_util._get_default_warm_start_settings(
- warm_start_from)
- if warm_start_settings:
- warm_starting_util._warm_start(warm_start_settings)
- # pylint: enable=protected-access
-
- return estimator_spec
super(DNNLinearCombinedRegressor, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn, model_dir=model_dir, config=config,
+ warm_start_from=warm_start_from)
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index 2bdec69303..706575985f 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -877,7 +877,7 @@ class BaseDNNWarmStartingTest(object):
# Create a second DNNClassifier, warm-started from the first. Use a
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
- # accumulator values that change). Use a a new FeatureColumn with a
+ # accumulator values that change). Use a new FeatureColumn with a
# different vocabulary for occupation.
new_vocab_list = ['doctor', 'consultant', 'engineer']
new_vocab_file = os.path.join(self._ckpt_and_vocab_dir,
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 204e1119f2..cb9e3fc6ca 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -24,6 +24,7 @@ import collections
import six
from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
@@ -371,6 +372,64 @@ def _check_logits_final_dim(logits, expected_logits_dimension):
return array_ops.identity(logits, name=scope)
+def _validate_loss_fn_args(loss_fn):
+ """Validates loss_fn arguments.
+
+ Required arguments: labels, logits.
+ Optional arguments: features.
+
+ Args:
+ loss_fn: The loss function.
+ Raises:
+ ValueError: If the signature is unexpected.
+ """
+ loss_fn_args = util.fn_args(loss_fn)
+ for required_arg in ['labels', 'logits']:
+ if required_arg not in loss_fn_args:
+ raise ValueError(
+ 'loss_fn must contain argument: {}. '
+ 'Given arguments: {}'.format(required_arg, loss_fn_args))
+ invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features']))
+ if invalid_args:
+ raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))
+
+
+def _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1):
+ """Calls loss_fn and checks the returned shape.
+
+ Args:
+ loss_fn: The loss function.
+ labels: Processed labels Tensor.
+ logits: Logits Tensor of shape [D0, D1, ... DN, logits_dimension].
+ features: Features dict.
+ expected_loss_dim: The expected last dimension of loss Tensor.
+ Returns:
+ Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim].
+ """
+ loss_fn_args = util.fn_args(loss_fn)
+ kwargs = {}
+ if 'features' in loss_fn_args:
+ kwargs['features'] = features
+ with ops.name_scope(
+ None, 'call_loss_fn',
+ values=[labels, logits] + list(six.itervalues(features))):
+ unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)
+ logits_shape = array_ops.shape(logits, name='logits_shape')
+ expected_loss_shape = array_ops.concat(
+ [logits_shape[:-1], [expected_loss_dim]], axis=0,
+ name='expected_loss_shape')
+ loss_shape = array_ops.shape(unweighted_loss, name='loss_shape')
+ check_loss_shape_op = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(loss_shape, expected_loss_shape)),
+ data=[
+ 'loss_fn must return Tensor of shape '
+ '[D0, D1, ... DN, {}]. '.format(expected_loss_dim),
+ 'logits_shape: ', logits_shape, 'loss_shape: ', loss_shape],
+ name='check_loss_shape')
+ with ops.control_dependencies([check_loss_shape_op]):
+ return array_ops.identity(unweighted_loss)
+
+
def _indicator_labels_mean(labels, weights=None, name=None):
with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope:
labels = math_ops.to_float(labels, name='labels')
@@ -467,6 +526,7 @@ def _multi_class_head_with_softmax_cross_entropy_loss(
weight_column=None,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""Creates a '_Head' for multi class classification.
@@ -485,6 +545,12 @@ def _multi_class_head_with_softmax_cross_entropy_loss(
labels have shape `[batch_size, 1]`, the loss is the weighted sum over
`batch_size`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with
+ shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
+ the input labels before passing them to `loss_fn`.
+
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
@@ -499,6 +565,7 @@ def _multi_class_head_with_softmax_cross_entropy_loss(
`label_vocabulary` is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -517,11 +584,14 @@ def _multi_class_head_with_softmax_cross_entropy_loss(
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+ if loss_fn:
+ _validate_loss_fn_args(loss_fn)
return _MultiClassHeadWithSoftmaxCrossEntropyLoss(
n_classes=n_classes,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
@@ -533,6 +603,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
weight_column=None,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
if (n_classes is None) or (n_classes <= 2):
raise ValueError('n_classes must be > 2: %s.' % n_classes)
@@ -540,6 +611,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
self._weight_column = weight_column
self._label_vocabulary = label_vocabulary
self._loss_reduction = loss_reduction
+ self._loss_fn = loss_fn
self._name = name
@property
@@ -602,10 +674,15 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
labels = _check_dense_labels_match_logits_and_reshape(
labels=labels, logits=logits, expected_labels_dimension=1)
label_ids = self._label_ids(labels)
- unweighted_loss = losses.sparse_softmax_cross_entropy(
- labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
- # Restore the squeezed dim, so unweighted_loss matches the weights shape.
- unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1)
+ if self._loss_fn:
+ unweighted_loss = _call_loss_fn(
+ loss_fn=self._loss_fn, labels=label_ids, logits=logits,
+ features=features, expected_loss_dim=1)
+ else:
+ unweighted_loss = losses.sparse_softmax_cross_entropy(
+ labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
+ # Restore the squeezed dim, so unweighted_loss matches the weights shape.
+ unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1)
weights = _get_weights_and_check_match_logits(
features=features, weight_column=self._weight_column, logits=logits)
training_loss = losses.compute_weighted_loss(
@@ -627,15 +704,15 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
For many applications, the shape is `[batch_size, logits_dimension]`.
labels: Labels integer or string `Tensor` with shape matching `logits`,
- namely `[D0, D1, ... DN, 1]`. `labels` is required argument when `mode`
- equals `TRAIN` or `EVAL`.
+ namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is
+ required argument when `mode` equals `TRAIN` or `EVAL`.
train_op_fn: Function that takes a scalar loss `Tensor` and returns
`train_op`. Required in TRAIN mode.
regularization_losses: A list of additional scalar losses to be added to
the training loss, such as regularization losses. These losses are
usually expressed as a batch average, so for best results users need to
- set `loss_reduction=MEAN_PER_ELEMENT` or
- `loss_reduction=SUM_BY_NONZERO_WEIGHTS` when creating the head to
+ set `loss_reduction=SUM_OVER_BATCH_SIZE` or
+ `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
avoid scaling errors.
Returns:
`EstimatorSpec`.
@@ -734,8 +811,12 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
- weight_column=None, thresholds=None, label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM, name=None):
+ weight_column=None,
+ thresholds=None,
+ label_vocabulary=None,
+ loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
+ name=None):
"""Creates a `_Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
@@ -755,6 +836,12 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
labels have shape `[batch_size, 1]`, the loss is the weighted sum over
`batch_size`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with
+ shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
+ the input labels before passing them to `loss_fn`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -772,6 +859,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -795,11 +883,14 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+ if loss_fn:
+ _validate_loss_fn_args(loss_fn)
return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
weight_column=weight_column,
thresholds=thresholds,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
@@ -811,11 +902,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
thresholds=None,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
self._weight_column = weight_column
self._thresholds = thresholds
self._label_vocabulary = label_vocabulary
self._loss_reduction = loss_reduction
+ self._loss_fn = loss_fn
self._name = name
@property
@@ -827,10 +920,10 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
return 1
def _eval_metric_ops(self, labels, logits, logistic, class_ids, weights,
- unreduced_loss):
+ unreduced_loss, regularization_loss):
with ops.name_scope(None, 'metrics',
(labels, logits, logistic, class_ids, weights,
- unreduced_loss)):
+ unreduced_loss, regularization_loss)):
keys = metric_keys.MetricKeys
labels_mean = _indicator_labels_mean(
labels=labels, weights=weights, name=keys.LABEL_MEAN)
@@ -870,6 +963,11 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
curve='PR',
name=keys.AUC_PR)
}
+ if regularization_loss is not None:
+ metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = (
+ metrics_lib.mean(
+ values=regularization_loss,
+ name=keys.LOSS_REGULARIZATION))
for threshold in self._thresholds:
accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
metric_ops[_summary_key(self._name,
@@ -911,8 +1009,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
name='class_id_lookup').lookup(labels)
labels = math_ops.to_float(labels)
labels = _assert_range(labels, 2)
- unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
- labels=labels, logits=logits)
+ if self._loss_fn:
+ unweighted_loss = _call_loss_fn(
+ loss_fn=self._loss_fn, labels=labels, logits=logits,
+ features=features, expected_loss_dim=1)
+ else:
+ unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
+ labels=labels, logits=logits)
weights = _get_weights_and_check_match_logits(
features=features, weight_column=self._weight_column, logits=logits)
training_loss = losses.compute_weighted_loss(
@@ -924,8 +1027,31 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
processed_labels=labels)
def create_estimator_spec(
- self, features, mode, logits, labels=None, train_op_fn=None):
- """See `Head`."""
+ self, features, mode, logits, labels=None, train_op_fn=None,
+ regularization_losses=None):
+ """Returns an `EstimatorSpec`.
+
+ Args:
+ features: Input `dict` of `Tensor` or `SparseTensor` objects.
+ mode: Estimator's `ModeKeys`.
+ logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many
+ applications, the shape is `[batch_size, 1]`.
+ labels: Labels integer or string `Tensor` with shape matching `logits`,
+ namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required
+ argument when `mode` equals `TRAIN` or `EVAL`.
+ train_op_fn: Function that takes a scalar loss `Tensor` and returns
+ `train_op`. Required in TRAIN mode.
+ regularization_losses: A list of additional scalar losses to be added to
+ the training loss, such as regularization losses. These losses are
+ usually expressed as a batch average, so for best results users need to
+ set `loss_reduction=SUM_OVER_BATCH_SIZE` or
+ `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
+ avoid scaling errors.
+ Returns:
+ `EstimatorSpec`.
+ Raises:
+ ValueError: If `train_op_fn` is `None` in TRAIN mode.
+ """
# Predict.
with ops.name_scope(self._name, 'head'):
with ops.name_scope(None, 'predictions', (logits,)):
@@ -972,20 +1098,28 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
(training_loss, unreduced_loss, weights, processed_labels) = (
self.create_loss(
features=features, mode=mode, logits=logits, labels=labels))
+ if regularization_losses:
+ regularization_loss = math_ops.add_n(regularization_losses)
+ regularized_training_loss = math_ops.add_n(
+ [training_loss, regularization_loss])
+ else:
+ regularization_loss = None
+ regularized_training_loss = training_loss
# Eval.
if mode == model_fn.ModeKeys.EVAL:
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
- loss=training_loss,
+ loss=regularized_training_loss,
eval_metric_ops=self._eval_metric_ops(
labels=processed_labels,
logits=logits,
logistic=logistic,
class_ids=class_ids,
weights=weights,
- unreduced_loss=unreduced_loss))
+ unreduced_loss=unreduced_loss,
+ regularization_loss=regularization_loss))
# Train.
if train_op_fn is None:
@@ -999,24 +1133,29 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
else:
mean_loss = None
with ops.name_scope(''):
+ keys = metric_keys.MetricKeys
summary.scalar(
- _summary_key(self._name, metric_keys.MetricKeys.LOSS),
- training_loss)
+ _summary_key(self._name, keys.LOSS),
+ regularized_training_loss)
if mean_loss is not None:
summary.scalar(
- _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN),
- mean_loss)
+ _summary_key(self._name, keys.LOSS_MEAN), mean_loss)
+ if regularization_loss is not None:
+ summary.scalar(
+ _summary_key(self._name, keys.LOSS_REGULARIZATION),
+ regularization_loss)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
- loss=training_loss,
- train_op=train_op_fn(training_loss))
+ loss=regularized_training_loss,
+ train_op=train_op_fn(regularized_training_loss))
def _regression_head_with_mean_squared_error_loss(
weight_column=None,
label_dimension=1,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""Creates a `_Head` for regression using the `mean_squared_error` loss.
@@ -1035,6 +1174,10 @@ def _regression_head_with_mean_squared_error_loss(
`[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
`[D0, D1, ... DN, label_dimension]`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, label_dimension]`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -1045,6 +1188,7 @@ def _regression_head_with_mean_squared_error_loss(
`[batch_size, label_dimension]`).
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -1057,10 +1201,13 @@ def _regression_head_with_mean_squared_error_loss(
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+ if loss_fn:
+ _validate_loss_fn_args(loss_fn)
return _RegressionHeadWithMeanSquaredErrorLoss(
weight_column=weight_column,
label_dimension=label_dimension,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
@@ -1072,6 +1219,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
label_dimension,
weight_column=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""`Head` for regression."""
if label_dimension < 1:
@@ -1079,6 +1227,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
self._logits_dimension = label_dimension
self._weight_column = weight_column
self._loss_reduction = loss_reduction
+ self._loss_fn = loss_fn
self._name = name
@property
@@ -1097,8 +1246,13 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
labels=labels, logits=logits,
expected_labels_dimension=self._logits_dimension)
labels = math_ops.to_float(labels)
- unweighted_loss = losses.mean_squared_error(
- labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
+ if self._loss_fn:
+ unweighted_loss = _call_loss_fn(
+ loss_fn=self._loss_fn, labels=labels, logits=logits,
+ features=features, expected_loss_dim=self._logits_dimension)
+ else:
+ unweighted_loss = losses.mean_squared_error(
+ labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
weights = _get_weights_and_check_match_logits(
features=features, weight_column=self._weight_column, logits=logits,
allow_per_logit_weights=True)
@@ -1111,7 +1265,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
processed_labels=labels)
def create_estimator_spec(
- self, features, mode, logits, labels=None, train_op_fn=None):
+ self, features, mode, logits, labels=None, train_op_fn=None,
+ regularization_losses=None):
"""Returns an `EstimatorSpec`.
Args:
@@ -1125,6 +1280,12 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
`mode` equals `TRAIN` or `EVAL`.
train_op_fn: Function that takes a scalar loss `Tensor` and returns
`train_op`. Required in TRAIN mode.
+ regularization_losses: A list of additional scalar losses to be added to
+ the training loss, such as regularization losses. These losses are
+ usually expressed as a batch average, so for best results users need to
+ set `loss_reduction=SUM_OVER_BATCH_SIZE` or
+ `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
+ avoid scaling errors.
Returns:
`EstimatorSpec`.
Raises:
@@ -1147,20 +1308,34 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
training_loss, unreduced_loss, weights, _ = self.create_loss(
features=features, mode=mode, logits=logits, labels=labels)
+ if regularization_losses:
+ regularization_loss = math_ops.add_n(regularization_losses)
+ regularized_training_loss = math_ops.add_n(
+ [training_loss, regularization_loss])
+ else:
+ regularization_loss = None
+ regularized_training_loss = training_loss
# Eval.
if mode == model_fn.ModeKeys.EVAL:
+ keys = metric_keys.MetricKeys
# Estimator already adds a metric for loss.
eval_metric_ops = {
- _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN):
+ _summary_key(self._name, keys.LOSS_MEAN):
metrics_lib.mean(
values=unreduced_loss,
weights=weights)
}
+ if regularization_loss is not None:
+ regularization_loss_key = _summary_key(
+ self._name, keys.LOSS_REGULARIZATION)
+ eval_metric_ops[regularization_loss_key] = metrics_lib.mean(
+ values=regularization_loss,
+ name=keys.LOSS_REGULARIZATION)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
- loss=training_loss,
+ loss=regularized_training_loss,
eval_metric_ops=eval_metric_ops)
# Train.
@@ -1175,18 +1350,22 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
else:
mean_loss = None
with ops.name_scope(''):
+ keys = metric_keys.MetricKeys
summary.scalar(
- _summary_key(self._name, metric_keys.MetricKeys.LOSS),
- training_loss)
+ _summary_key(self._name, keys.LOSS),
+ regularized_training_loss)
if mean_loss is not None:
summary.scalar(
- _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN),
- mean_loss)
+ _summary_key(self._name, keys.LOSS_MEAN), mean_loss)
+ if regularization_loss is not None:
+ summary.scalar(
+ _summary_key(self._name, keys.LOSS_REGULARIZATION),
+ regularization_loss)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
- loss=training_loss,
- train_op=train_op_fn(training_loss))
+ loss=regularized_training_loss,
+ train_op=train_op_fn(regularized_training_loss))
def _assert_range(labels, n_classes, message=None):
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 28b8e635fb..3a03770af4 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -111,6 +111,41 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes=3, loss_reduction=losses.Reduction.NONE)
+ def test_loss_fn_arg_labels_missing(self):
+ def _loss_fn(logits):
+ del logits # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: labels\. '
+ r'Given arguments: \(\'logits\',\)'):
+ head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_logits_missing(self):
+ def _loss_fn(labels):
+ del labels # unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: logits\. '
+ r'Given arguments: \(\'labels\',\)'):
+ head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_features_ok(self):
+ def _loss_fn(labels, logits, features):
+ del labels, logits, features # Unused
+ head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_invalid(self):
+ def _loss_fn(labels, logits, name=None):
+ del labels, logits, name # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn has unexpected args: \[\'name\'\]'):
+ head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
def test_invalid_logits_shape(self):
n_classes = 3
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)
@@ -406,6 +441,56 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
+ def test_eval_create_loss_loss_fn(self):
+ """Tests head.create_loss for eval mode and custom loss_fn."""
+ loss = np.array([[1.], [2.]], dtype=np.float32)
+ logits_input = np.array([[-10., 10., 0.], [-15., 10., 0]], dtype=np.float32)
+ labels_input = np.array([[1], [2]], dtype=np.int64)
+ def _loss_fn(labels, logits):
+ check_labels = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(labels, labels_input)),
+ data=[labels])
+ check_logits = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(logits, logits_input)),
+ data=[logits])
+ with ops.control_dependencies([check_labels, check_logits]):
+ return constant_op.constant(loss)
+ head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits_input,
+ labels=labels_input)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(np.sum(loss), actual_training_loss.eval())
+
+ def test_eval_create_loss_loss_fn_wrong_shape(self):
+ """Tests custom loss_fn that returns Tensor of unexpected shape."""
+ loss = np.array([1., 2.], dtype=np.float32)
+ def _loss_fn(labels, logits):
+ del labels, logits # Unused
+ return constant_op.constant(loss)
+ head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ logits = np.array([[-10., 10., 0.], [-15., 10., 0.]], dtype=np.float32)
+ labels = np.array([[1], [2]], dtype=np.int64)
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] '
+ r'\[logits_shape: \] \[2 3\] \[loss_shape: \] \[2\]'):
+ actual_training_loss.eval()
+
def test_eval_labels_none(self):
"""Tests that error is raised when labels is None."""
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
@@ -487,7 +572,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
def test_eval_with_regularization_losses(self):
n_classes = 3
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
- n_classes, loss_reduction=losses.Reduction.MEAN)
+ n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)
labels = np.array(((1,), (1,)), dtype=np.int64)
features = {'x': np.array(((42,),), dtype=np.int32)}
@@ -790,7 +875,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
def test_train_with_regularization_losses(self):
n_classes = 3
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
- n_classes, loss_reduction=losses.Reduction.MEAN)
+ n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)
labels = np.array(((1,), (1,)), dtype=np.int64)
@@ -1204,6 +1289,41 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
loss_reduction=losses.Reduction.NONE)
+ def test_loss_fn_arg_labels_missing(self):
+ def _loss_fn(logits):
+ del logits # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: labels\. '
+ r'Given arguments: \(\'logits\',\)'):
+ head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_logits_missing(self):
+ def _loss_fn(labels):
+ del labels # unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: logits\. '
+ r'Given arguments: \(\'labels\',\)'):
+ head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_features_ok(self):
+ def _loss_fn(labels, logits, features):
+ del labels, logits, features # Unused
+ head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_invalid(self):
+ def _loss_fn(labels, logits, name=None):
+ del labels, logits, name # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn has unexpected args: \[\'name\'\]'):
+ head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
def test_invalid_logits_shape(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
self.assertEqual(1, head.logits_dimension)
@@ -1485,6 +1605,53 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
]
self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())
+ def test_eval_with_regularization_losses(self):
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ logits = np.array(((45,), (-41,),), dtype=np.float32)
+ labels = np.array(((1,), (1,),), dtype=np.int32)
+ features = {'x': np.array(((42,),), dtype=np.int32)}
+ regularization_losses = [1.5, 0.5]
+ expected_regularization_loss = 2.
+ # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size
+ # = sum(0, 41) / 2 = 20.5
+ expected_unregularized_loss = 20.5
+ expected_regularized_loss = (
+ expected_unregularized_loss + expected_regularization_loss)
+
+ # Create estimator spec.
+ spec = head.create_estimator_spec(
+ features=features,
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels,
+ regularization_losses=regularization_losses)
+
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ keys.LOSS_MEAN: expected_unregularized_loss,
+ keys.LOSS_REGULARIZATION: expected_regularization_loss,
+ keys.ACCURACY: 1./2,
+ keys.PREDICTION_MEAN: 1./2,
+ keys.LABEL_MEAN: 2./2,
+ keys.ACCURACY_BASELINE: 2./2,
+ keys.AUC: 0.,
+ keys.AUC_PR: 1.,
+ }
+
+ # Assert predictions, loss, and metrics.
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
+ update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
+ loss, metrics = sess.run((spec.loss, update_ops))
+ self.assertAllClose(expected_regularized_loss, loss)
+ # Check results of both update (in `metrics`) and value ops.
+ self.assertAllClose(expected_metrics, metrics)
+ self.assertAllClose(
+ expected_metrics, {k: value_ops[k].eval() for k in value_ops})
+
def test_eval_with_vocabulary_list_create_loss(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
label_vocabulary=['aang', 'iroh'])
@@ -1652,6 +1819,56 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
self.assertAllClose(expected_weights, actual_weights)
+ def test_eval_create_loss_loss_fn(self):
+ """Tests head.create_loss for eval mode and custom loss_fn."""
+ loss = np.array([[1.], [2.]], dtype=np.float32)
+ logits_input = np.array([[-10.], [10.]], dtype=np.float32)
+ labels_input = np.array([[1], [0]], dtype=np.int64)
+ def _loss_fn(labels, logits):
+ check_labels = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(labels, labels_input)),
+ data=[labels])
+ check_logits = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(logits, logits_input)),
+ data=[logits])
+ with ops.control_dependencies([check_labels, check_logits]):
+ return constant_op.constant(loss)
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits_input,
+ labels=labels_input)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(np.sum(loss), actual_training_loss.eval())
+
+ def test_eval_create_loss_loss_fn_wrong_shape(self):
+ """Tests custom loss_fn that returns Tensor of unexpected shape."""
+ loss = np.array([1., 2.], dtype=np.float32)
+ def _loss_fn(labels, logits):
+ del labels, logits # Unused
+ return constant_op.constant(loss)
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ logits = np.array([[-10.], [10.]], dtype=np.float32)
+ labels = np.array([[1], [0]], dtype=np.int64)
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] '
+ r'\[logits_shape: \] \[2 1\] \[loss_shape: \] \[2\]'):
+ actual_training_loss.eval()
+
def test_train_labels_none(self):
"""Tests that error is raised when labels is None."""
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
@@ -1749,6 +1966,49 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
},
summary_str)
+ def test_train_with_regularization_losses(self):
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+
+ logits = np.array(((45,), (-41,),), dtype=np.float32)
+ labels = np.array(((1,), (1,),), dtype=np.float64)
+ expected_train_result = b'my_train_op'
+ features = {'x': np.array(((42,),), dtype=np.float32)}
+ regularization_losses = [1.5, 0.5]
+ expected_regularization_loss = 2.
+ # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size
+ # = sum(0, 41) / 2 = 20.5
+ # loss = unregularized_loss + regularization_loss = 7.
+ expected_loss = 22.5
+ def _train_op_fn(loss):
+ with ops.control_dependencies((check_ops.assert_equal(
+ math_ops.to_float(expected_loss), math_ops.to_float(loss),
+ name='assert_loss'),)):
+ return constant_op.constant(expected_train_result)
+
+ # Create estimator spec.
+ spec = head.create_estimator_spec(
+ features=features,
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ regularization_losses=regularization_losses)
+
+ # Assert predictions, loss, train_op, and summaries.
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNotNone(spec.scaffold.summary_op)
+ loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
+ spec.scaffold.summary_op))
+ self.assertAllClose(expected_loss, loss)
+ self.assertEqual(expected_train_result, train_result)
+ _assert_simple_summaries(self, {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ metric_keys.MetricKeys.LOSS_REGULARIZATION: (
+ expected_regularization_loss),
+ }, summary_str)
+
def test_float_labels_train_create_loss(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
@@ -2265,6 +2525,37 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
head_lib._regression_head_with_mean_squared_error_loss(
loss_reduction=losses.Reduction.NONE)
+ def test_loss_fn_arg_labels_missing(self):
+ def _loss_fn(logits):
+ del logits # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: labels\. '
+ r'Given arguments: \(\'logits\',\)'):
+ head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_logits_missing(self):
+ def _loss_fn(labels):
+ del labels # unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: logits\. '
+ r'Given arguments: \(\'labels\',\)'):
+ head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_features_ok(self):
+ def _loss_fn(labels, logits, features):
+ del labels, logits, features # Unused
+ head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_invalid(self):
+ def _loss_fn(labels, logits, name=None):
+ del labels, logits, name # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn has unexpected args: \[\'name\'\]'):
+ head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn)
+
def test_invalid_logits(self):
head = head_lib._regression_head_with_mean_squared_error_loss(
label_dimension=3)
@@ -2440,6 +2731,56 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
# loss = [(43-45)^2, (44-41)] = [4, 9]
self.assertAllClose(13., training_loss.eval())
+ def test_eval_create_loss_loss_fn(self):
+ """Tests head.create_loss for eval mode and custom loss_fn."""
+ loss = np.array([[0., 1.], [2., 3.]], dtype=np.float32)
+ logits_input = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32)
+ labels_input = np.array([[1., 0.], [2., -1.]], dtype=np.float32)
+ def _loss_fn(labels, logits):
+ check_labels = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(labels, labels_input)),
+ data=[labels])
+ check_logits = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(logits, logits_input)),
+ data=[logits])
+ with ops.control_dependencies([check_labels, check_logits]):
+ return constant_op.constant(loss)
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ label_dimension=2, loss_fn=_loss_fn)
+
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits_input,
+ labels=labels_input)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(np.sum(loss), actual_training_loss.eval())
+
+ def test_eval_create_loss_loss_fn_wrong_shape(self):
+ """Tests custom loss_fn that returns Tensor of unexpected shape."""
+ loss = np.array([[1.], [2.]], dtype=np.float32)
+ def _loss_fn(labels, logits):
+ del labels, logits # Unused
+ return constant_op.constant(loss)
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ label_dimension=2, loss_fn=_loss_fn)
+
+ logits = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32)
+ labels = np.array([[1., 0.], [2., -1.]], dtype=np.float32)
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 2\]\. \] '
+ r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2 1\]'):
+ actual_training_loss.eval()
+
def test_eval_labels_none(self):
"""Tests that error is raised when labels is None."""
head = head_lib._regression_head_with_mean_squared_error_loss()
@@ -2512,6 +2853,51 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
]
self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())
+ def test_eval_with_regularization_losses(self):
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ self.assertEqual(1, head.logits_dimension)
+
+ logits = np.array(((45,), (41,),), dtype=np.float32)
+ labels = np.array(((43,), (44,),), dtype=np.int32)
+ features = {'x': np.array(((42,),), dtype=np.float32)}
+ regularization_losses = [1.5, 0.5]
+ expected_regularization_loss = 2.
+ # unregularized_loss = ((43-45)^2 + (44-41)^2) / batch_size
+ # = (4 + 9) / 2 = 6.5
+ expected_unregularized_loss = 6.5
+ expected_regularized_loss = (
+ expected_unregularized_loss + expected_regularization_loss)
+ # Create estimator spec.
+ spec = head.create_estimator_spec(
+ features=features,
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels,
+ regularization_losses=regularization_losses)
+
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ keys.LOSS_MEAN: expected_unregularized_loss,
+ keys.LOSS_REGULARIZATION: expected_regularization_loss,
+ }
+
+ # Assert predictions, loss, and metrics.
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNone(spec.scaffold.summary_op)
+ value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
+ update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
+ prediction_key = prediction_keys.PredictionKeys.PREDICTIONS
+ predictions, loss, metrics = sess.run((
+ spec.predictions[prediction_key], spec.loss, update_ops))
+ self.assertAllClose(logits, predictions)
+ self.assertAllClose(expected_regularized_loss, loss)
+ # Check results of both update (in `metrics`) and value ops.
+ self.assertAllClose(expected_metrics, metrics)
+ self.assertAllClose(
+ expected_metrics, {k: value_ops[k].eval() for k in value_ops})
+
def test_train_create_loss(self):
head = head_lib._regression_head_with_mean_squared_error_loss()
logits = np.array(((45,), (41,),), dtype=np.float32)
@@ -2666,6 +3052,53 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
},
summary_str)
+ def test_train_with_regularization_losses(self):
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ self.assertEqual(1, head.logits_dimension)
+
+ # Create estimator spec.
+ logits = np.array(((45,), (41,),), dtype=np.float32)
+ labels = np.array(((43.,), (44.,),), dtype=np.float64)
+ expected_train_result = b'my_train_op'
+ features = {'x': np.array(((42.,),), dtype=np.float32)}
+ regularization_losses = [1.5, 0.5]
+ expected_regularization_loss = 2.
+ # unregularized_loss = ((43-45)^2 + (44-41)^2) / batch_size
+ # = (4 + 9) / 2 = 6.5
+ # loss = unregularized_loss + regularization_loss = 8.5
+ expected_loss = 8.5
+ def _train_op_fn(loss):
+ with ops.control_dependencies((check_ops.assert_equal(
+ math_ops.to_float(expected_loss), math_ops.to_float(loss),
+ name='assert_loss'),)):
+ return constant_op.constant(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features=features,
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ regularization_losses=regularization_losses)
+
+ # Assert predictions, loss, train_op, and summaries.
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ self.assertIsNotNone(spec.scaffold.summary_op)
+ prediction_key = prediction_keys.PredictionKeys.PREDICTIONS
+ predictions, loss, train_result, summary_str = sess.run((
+ spec.predictions[prediction_key], spec.loss, spec.train_op,
+ spec.scaffold.summary_op))
+ self.assertAllClose(logits, predictions)
+ self.assertAllClose(expected_loss, loss)
+ self.assertEqual(expected_train_result, train_result)
+ _assert_simple_summaries(self, {
+ metric_keys.MetricKeys.LOSS: expected_loss,
+ metric_keys.MetricKeys.LOSS_REGULARIZATION: (
+ expected_regularization_loss),
+ }, summary_str)
+
def test_weighted_multi_example_eval(self):
"""1d label, 3 examples, 1 batch."""
head = head_lib._regression_head_with_mean_squared_error_loss(
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 97cfd24a10..a5b1172e72 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -23,7 +23,6 @@ import math
import six
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import warm_starting_util
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.feature_column import feature_column as feature_column_lib
@@ -31,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
from tensorflow.python.training import training_util
@@ -245,7 +245,8 @@ class LinearClassifier(estimator.Estimator):
optimizer='Ftrl',
config=None,
partitioner=None,
- warm_start_from=None):
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM):
"""Construct a `LinearClassifier` estimator object.
Args:
@@ -282,6 +283,8 @@ class LinearClassifier(estimator.Estimator):
string filepath is provided instead of a `WarmStartSettings`, then all
weights and biases are warm-started, and it is assumed that vocabularies
and Tensor names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM`.
Returns:
A `LinearClassifier` estimator.
@@ -292,15 +295,17 @@ class LinearClassifier(estimator.Estimator):
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
def _model_fn(features, labels, mode, config):
- """Call the defined shared _linear_model_fn and possibly warm-start."""
- estimator_spec = _linear_model_fn(
+ """Call the defined shared _linear_model_fn."""
+ return _linear_model_fn(
features=features,
labels=labels,
mode=mode,
@@ -309,19 +314,12 @@ class LinearClassifier(estimator.Estimator):
optimizer=optimizer,
partitioner=partitioner,
config=config)
- # pylint: disable=protected-access
- warm_start_settings = warm_starting_util._get_default_warm_start_settings(
- warm_start_from)
- if warm_start_settings:
- warm_starting_util._warm_start(warm_start_settings)
- # pylint: enable=protected-access
-
- return estimator_spec
super(LinearClassifier, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
- config=config)
+ config=config,
+ warm_start_from=warm_start_from)
class LinearRegressor(estimator.Estimator):
@@ -388,7 +386,8 @@ class LinearRegressor(estimator.Estimator):
optimizer='Ftrl',
config=None,
partitioner=None,
- warm_start_from=None):
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM):
"""Initializes a `LinearRegressor` instance.
Args:
@@ -417,13 +416,16 @@ class LinearRegressor(estimator.Estimator):
string filepath is provided instead of a `WarmStartSettings`, then all
weights and biases are warm-started, and it is assumed that vocabularies
and Tensor names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM`.
"""
head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access
- label_dimension=label_dimension, weight_column=weight_column)
+ label_dimension=label_dimension, weight_column=weight_column,
+ loss_reduction=loss_reduction)
def _model_fn(features, labels, mode, config):
- """Call the defined shared _linear_model_fn and possibly warm-start."""
- estimator_spec = _linear_model_fn(
+ """Call the defined shared _linear_model_fn."""
+ return _linear_model_fn(
features=features,
labels=labels,
mode=mode,
@@ -432,16 +434,9 @@ class LinearRegressor(estimator.Estimator):
optimizer=optimizer,
partitioner=partitioner,
config=config)
- # pylint: disable=protected-access
- warm_start_settings = warm_starting_util._get_default_warm_start_settings(
- warm_start_from)
- if warm_start_settings:
- warm_starting_util._warm_start(warm_start_settings)
- # pylint: enable=protected-access
-
- return estimator_spec
super(LinearRegressor, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
- config=config)
+ config=config,
+ warm_start_from=warm_start_from)
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index cccb9af4b2..3e9183cf1b 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -2003,7 +2003,7 @@ class BaseLinearWarmStartingTest(object):
# Create a second LinearClassifier, warm-started from the first. Use a
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
- # accumulator values that change). Use a a new FeatureColumn with a
+ # accumulator values that change). Use a new FeatureColumn with a
# different vocabulary for occupation.
new_vocab_list = ['doctor', 'consultant', 'engineer']
new_vocab_file = os.path.join(self._ckpt_and_vocab_dir,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 90eecc1fda..96555b5e03 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -35,6 +35,7 @@ from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util
+from tensorflow.python.estimator import warm_starting_util
from tensorflow.python.estimator.export.export import build_all_signature_defs
from tensorflow.python.estimator.export.export import get_temp_export_dir
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
@@ -54,6 +55,7 @@ from tensorflow.python.training import saver
from tensorflow.python.training import training
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
+from tensorflow.python.util import compat_internal
from tensorflow.python.util import nest
@@ -96,9 +98,22 @@ class Estimator(object):
@end_compatibility
"""
- def __init__(self, model_fn, model_dir=None, config=None, params=None):
+ def __init__(self, model_fn, model_dir=None, config=None, params=None,
+ warm_start_from=None):
"""Constructs an `Estimator` instance.
+ See @{$estimators} for more information. To warm-start an `Estimator`:
+
+ ```python
+ estimator = tf.estimator.DNNClassifier(
+ feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+ warm_start_from="/path/to/checkpoint/dir")
+ ```
+
+ For more details on warm-start configuration, see
+ @{tf.estimator.WarmStartSettings$WarmStartSettings}.
+
Args:
model_fn: Model function. Follows the signature:
@@ -128,12 +143,19 @@ class Estimator(object):
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
- continue training a previously saved model. If `None`, the model_dir in
- `config` will be used if set. If both are set, they must be same. If
- both are `None`, a temporary directory will be used.
+ continue training a previously saved model. If `PathLike` object, the
+ path will be resolved. If `None`, the model_dir in `config` will be used
+ if set. If both are set, they must be same. If both are `None`, a
+ temporary directory will be used.
config: Configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
+ warm_start_from: Optional string filepath to a checkpoint to warm-start
+ from, or a `tf.estimator.WarmStartSettings` object to
+ fully configure warm-starting. If the string filepath is
+ provided instead of a `WarmStartSettings`, then all
+ variables are warm-started, and it is assumed that
+ vocabularies and Tensor names are unchanged.
Raises:
RuntimeError: If eager execution is enabled.
@@ -158,6 +180,7 @@ class Estimator(object):
self._config = config
# Model directory.
+ model_dir = compat_internal.path_to_str(model_dir)
if (model_dir is not None) and (self._config.model_dir is not None):
if model_dir != self._config.model_dir:
# TODO(alanyee): remove this suppression after it is no longer needed
@@ -190,6 +213,11 @@ class Estimator(object):
self._model_fn = model_fn
self._params = copy.deepcopy(params or {})
+ # pylint: disable=protected-access
+ self._warm_start_settings = (
+ warm_starting_util._get_default_warm_start_settings(warm_start_from))
+ # pylint: enable=protected-access
+
@property
def model_dir(self):
return self._model_dir
@@ -453,6 +481,7 @@ class Estimator(object):
with training.MonitoredSession(
session_creator=training.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint_path,
+ master=self._config.master,
scaffold=estimator_spec.scaffold,
config=self._session_config),
hooks=input_hooks + hooks) as mon_sess:
@@ -778,6 +807,13 @@ class Estimator(object):
worker_hooks.extend(input_hooks)
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
+
+ if self._warm_start_settings:
+ logging.info('Warm-starting with WarmStartSettings: %s' %
+ (self._warm_start_settings,))
+ # pylint: disable=protected-access
+ warm_starting_util._warm_start(self._warm_start_settings)
+ # pylint: enable=protected-access
# Check if the user created a loss summary, and add one if they didn't.
# We assume here that the summary is called 'loss'. If it is not, we will
# make another one with the name 'loss' to ensure it shows up in the right
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index ed1676a92d..833f3dcac3 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -52,6 +52,7 @@ from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
@@ -629,6 +630,33 @@ class EstimatorTrainTest(test.TestCase):
self.assertEqual(
10, estimator._load_global_step_from_checkpoint_dir(est.model_dir))
+ def test_warm_starts(self):
+ def _make_model_fn(x):
+ def _variable_creating_model_fn(features, labels, mode):
+ _, _ = features, labels
+ variable_scope.get_variable('x', initializer=x)
+ global_step = training.get_global_step()
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(global_step, 1))
+ return _variable_creating_model_fn
+
+ est = estimator.Estimator(model_fn=_make_model_fn(42.))
+ est.train(dummy_input_fn, steps=10)
+
+ warm_started_est = estimator.Estimator(
+ model_fn=_make_model_fn(36.),
+ warm_start_from=est.model_dir)
+ warm_started_est.train(dummy_input_fn, steps=5)
+ # warm_start is called after the model_fn, so x should have the value
+ # from the checkpoint.
+ self.assertEqual(42., warm_started_est.get_variable_value('x'))
+ # global_step should not be warm-started.
+ self.assertEqual(
+ 5, estimator._load_global_step_from_checkpoint_dir(
+ warm_started_est.model_dir))
+
def test_max_step(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
est.train(dummy_input_fn, max_steps=5)
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index 75c0e61d47..8e5d8141a1 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -47,10 +47,9 @@ except ImportError:
def _fill_array(arr, seq, fillvalue=0):
- """
- Recursively fills padded arr with elements from seq.
- If length of seq is less than arr padded length, fillvalue used.
+ """Recursively fills padded arr with elements from seq.
+ If length of seq is less than arr padded length, fillvalue used.
Args:
arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].
seq: Non-padded list of data sampels of shape
@@ -84,28 +83,30 @@ def _pad_if_needed(batch_key_item, fillvalue=0):
Raises:
ValueError if data samples have different shapes (except last padded dim).
"""
- shapes = [seq.shape[:-1] if len(seq.shape) > 0 else -1
- for seq in batch_key_item]
+ shapes = [
+ seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item
+ ]
if not all(shapes[0] == x for x in shapes):
raise ValueError("Array shapes must match.")
- last_length = [seq.shape[-1] if len(seq.shape) > 0 else 0
- for seq in batch_key_item]
+ last_length = [
+ seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item
+ ]
if all([x == last_length[0] for x in last_length]):
return batch_key_item
batch_size = len(batch_key_item)
max_sequence_length = max(last_length)
result_batch = np.zeros(
- shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
- dtype=batch_key_item[0].dtype)
+ shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
+ dtype=batch_key_item[0].dtype)
_fill_array(result_batch, batch_key_item, fillvalue)
return result_batch
-def _get_integer_indices_for_next_batch(
- batch_indices_start, batch_size, epoch_end, array_length,
- current_epoch, total_epochs):
+def _get_integer_indices_for_next_batch(batch_indices_start, batch_size,
+ epoch_end, array_length, current_epoch,
+ total_epochs):
"""Returns the integer indices for next batch.
If total epochs is not None and current epoch is the final epoch, the end
@@ -135,8 +136,9 @@ def _get_integer_indices_for_next_batch(
"Already emitted %s epochs." % current_epoch)
batch_indices_end = batch_indices_start + batch_size
- batch_indices = [j % array_length for j in
- range(batch_indices_start, batch_indices_end)]
+ batch_indices = [
+ j % array_length for j in range(batch_indices_start, batch_indices_end)
+ ]
epoch_end_indices = [i for i, x in enumerate(batch_indices) if x == epoch_end]
current_epoch += len(epoch_end_indices)
@@ -320,16 +322,20 @@ class _GeneratorFeedFn(object):
raise KeyError("key mismatch between dicts emitted by GenFun "
"Expected {} keys; got {}".format(
self._keys, data_row.keys()))
- list_dict.setdefault(self._col_placeholders[index],
- list()).append(data_row[key])
+ list_dict.setdefault(self._col_placeholders[index], list()).append(
+ data_row[key])
list_dict_size += 1
if self._pad_value is not None:
- feed_dict = {key: np.asarray(_pad_if_needed(item, self._pad_value))
- for key, item in list(list_dict.items())}
+ feed_dict = {
+ key: np.asarray(_pad_if_needed(item, self._pad_value))
+ for key, item in list(list_dict.items())
+ }
else:
- feed_dict = {key: np.asarray(item)
- for key, item in list(list_dict.items())}
+ feed_dict = {
+ key: np.asarray(item)
+ for key, item in list(list_dict.items())
+ }
return feed_dict
@@ -382,9 +388,8 @@ def _enqueue_data(data,
queue_shapes = [(), data.shape[1:]]
get_feed_fn = _ArrayFeedFn
elif isinstance(data, collections.OrderedDict):
- types = [dtypes.int64] + [
- dtypes.as_dtype(col.dtype) for col in data.values()
- ]
+ types = [dtypes.int64
+ ] + [dtypes.as_dtype(col.dtype) for col in data.values()]
queue_shapes = [()] + [col.shape[1:] for col in data.values()]
get_feed_fn = _OrderedDictNumpyFeedFn
elif isinstance(data, tp.FunctionType):
@@ -447,11 +452,11 @@ def _enqueue_data(data,
seed=seed)
elif pad_data:
min_after_dequeue = 0 # just for the summary text
- queue_shapes = list(map(
- lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
- queue_shapes))
+ queue_shapes = list(
+ map(lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
+ queue_shapes))
queue = data_flow_ops.PaddingFIFOQueue(
- capacity, dtypes=types, shapes=queue_shapes)
+ capacity, dtypes=types, shapes=queue_shapes)
else:
min_after_dequeue = 0 # just for the summary text
queue = data_flow_ops.FIFOQueue(
@@ -470,31 +475,35 @@ def _enqueue_data(data,
if not pad_data:
feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs))
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs))
else:
feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs,
- pad_value=pad_value))
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs,
+ pad_value=pad_value))
runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access
- queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns)
+ queue=queue,
+ enqueue_ops=enqueue_ops,
+ feed_fns=feed_fns)
queue_runner.add_queue_runner(runner)
- full = (math_ops.cast(
- math_ops.maximum(0, queue.size() - min_after_dequeue),
- dtypes.float32) * (1. / (capacity - min_after_dequeue)))
+ full = (
+ math_ops.cast(
+ math_ops.maximum(0,
+ queue.size() - min_after_dequeue), dtypes.float32)
+ * (1. / (capacity - min_after_dequeue)))
# Note that name contains a '/' at the end so we intentionally do not place
# a '/' after %s below.
summary_name = ("queue/%sfraction_over_%d_of_%d_full" %
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index dc714d4d22..e446b3e03a 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -27,6 +27,8 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
+from tensorflow.python.util import compat
+from tensorflow.python.util import compat_internal
_USE_DEFAULT = object()
@@ -399,7 +401,8 @@ class RunConfig(object):
Args:
model_dir: directory where model parameters, graph, etc are saved. If
- `None`, will use a default value set by the Estimator.
+ `PathLike` object, the path will be resolved. If `None`, will use a
+ default value set by the Estimator.
tf_random_seed: Random seed for TensorFlow initializers.
Setting this value allows consistency between reruns.
save_summary_steps: Save summaries every this many steps.
@@ -442,7 +445,8 @@ class RunConfig(object):
if tf_config:
logging.info('TF_CONFIG environment variable: %s', tf_config)
- model_dir = _get_model_dir(tf_config, model_dir)
+ model_dir = _get_model_dir(tf_config,
+ compat_internal.path_to_str(model_dir))
RunConfig._replace(
self,
@@ -484,7 +488,7 @@ class RunConfig(object):
self._num_ps_replicas = _count_ps(self._cluster_spec)
self._num_worker_replicas = _count_worker(
self._cluster_spec, chief_task_type=TaskType.CHIEF)
- self._global_id = _get_global_id_in_cluster(
+ self._global_id_in_cluster = _get_global_id_in_cluster(
self._cluster_spec,
self._task_type,
self._task_id,
@@ -495,14 +499,14 @@ class RunConfig(object):
self._master = _LOCAL_MASTER
self._num_ps_replicas = 0
self._num_worker_replicas = 0
- self._global_id = None # undefined
+ self._global_id_in_cluster = None # undefined
self._is_chief = self._task_type == TaskType.CHIEF
else:
# Local mode.
self._task_type = task_env.get(_TASK_TYPE_KEY, TaskType.WORKER)
self._task_id = int(task_env.get(_TASK_ID_KEY, 0))
- self._global_id = 0
+ self._global_id_in_cluster = 0
if self._task_type != TaskType.WORKER:
raise ValueError(
@@ -537,7 +541,7 @@ class RunConfig(object):
raise ValueError('If `master` node exists in `cluster`, task_type '
'`evaluator` is not supported.')
- self._global_id = _get_global_id_in_cluster(
+ self._global_id_in_cluster = _get_global_id_in_cluster(
self._cluster_spec,
self._task_type,
self._task_id,
@@ -619,7 +623,7 @@ class RunConfig(object):
Returns:
An integer id.
"""
- return self._global_id
+ return self._global_id_in_cluster
@property
def task_type(self):
diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py
index c748b318b7..ad95c71234 100644
--- a/tensorflow/python/estimator/warm_starting_util.py
+++ b/tensorflow/python/estimator/warm_starting_util.py
@@ -402,10 +402,10 @@ def _warm_start_var_with_vocab(var,
def _warm_start(warm_start_settings):
- """Warmstarts a model using the given settings.
+ """Warm-starts a model using the given settings.
- Currently, this is intended for use only in canned Estimators. Once made
- public, it can be used in any model_fn.
+ If you are using a tf.estimator.Estimator, this will automatically be called
+ during training.
Args:
warm_start_settings: An object of `WarmStartSettings`.
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 76d44fc474..a758f8a4fc 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -85,6 +85,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:partitioned_variables",
@@ -93,6 +94,8 @@ py_test(
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python/eager:backprop",
+ "//tensorflow/python/eager:context",
"//tensorflow/python/estimator:numpy_io",
],
)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index a7fe528ee1..7feb209cc4 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -657,11 +657,11 @@ def embedding_column(
trainable=trainable)
-def _shared_embedding_columns(
+def shared_embedding_columns(
categorical_columns, dimension, combiner='mean', initializer=None,
shared_embedding_collection_name=None, ckpt_to_load_from=None,
tensor_name_in_ckpt=None, max_norm=None, trainable=True):
- """List of `_DenseColumn`s that convert from sparse, categorical input.
+ """List of dense columns that convert from sparse, categorical input.
This is similar to `embedding_column`, except that that it produces a list of
embedding columns that share the same embedding weights.
@@ -670,7 +670,7 @@ def _shared_embedding_columns(
impression video IDs that share the same vocabulary), and you want to convert
them to a dense representation (e.g., to feed to a DNN).
- Inputs must be a list of `_CategoricalColumn` created by any of the
+ Inputs must be a list of categorical columns created by any of the
`categorical_column_*` function. They must all be of the same type and have
the same arguments except `key`. E.g. they can be
categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
@@ -714,7 +714,7 @@ def _shared_embedding_columns(
```
Args:
- categorical_columns: List of `_CategoricalColumn`s created by a
+ categorical_columns: List of categorical columns created by a
`categorical_column_with_*` function. These columns produce the sparse IDs
that are inputs to the embedding lookup. All columns must be of the same
type and have the same arguments except `key`. E.g. they can be
@@ -744,7 +744,7 @@ def _shared_embedding_columns(
trainable: Whether or not the embedding is trainable. Default is True.
Returns:
- A list of `_DenseColumn`s that converts from sparse input. The order of
+ A list of dense columns that converts from sparse input. The order of
results follows the ordering of `categorical_columns`.
Raises:
diff --git a/tensorflow/python/feature_column/feature_column_lib.py b/tensorflow/python/feature_column/feature_column_lib.py
index 8a57986764..505a1408d2 100644
--- a/tensorflow/python/feature_column/feature_column_lib.py
+++ b/tensorflow/python/feature_column/feature_column_lib.py
@@ -29,6 +29,7 @@ _allowed_symbols = [
'linear_model',
'make_parse_example_spec',
'embedding_column',
+ 'shared_embedding_columns',
'crossed_column',
'numeric_column',
'bucketized_column',
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 2374680b96..6f366e7722 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -29,7 +29,6 @@ from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column as fc_lib
from tensorflow.python.feature_column import feature_column_lib as fc
from tensorflow.python.feature_column.feature_column import _CategoricalColumn
from tensorflow.python.feature_column.feature_column import _DenseColumn
@@ -1072,6 +1071,7 @@ def get_linear_model_column_var(column):
'linear_model/' + column.name)[0]
+@test_util.with_c_api
class LinearModelTest(test.TestCase):
def test_raises_if_empty_feature_columns(self):
@@ -1325,10 +1325,16 @@ class LinearModelTest(test.TestCase):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- with _initialized_session():
- with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
- predictions.eval()
+ if ops._USE_C_API:
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ predictions = fc.linear_model(features, [price])
+ else:
+ predictions = fc.linear_model(features, [price])
+ with _initialized_session():
+ with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
+ predictions.eval()
def test_dense_reshaping(self):
price = fc.numeric_column('price', shape=[1, 2])
@@ -1791,6 +1797,7 @@ class InputLayerTest(test.TestCase):
self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
+@test_util.with_c_api
class FunctionalInputLayerTest(test.TestCase):
def test_raises_if_empty_feature_columns(self):
@@ -1855,10 +1862,16 @@ class FunctionalInputLayerTest(test.TestCase):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- net = fc.input_layer(features, [price])
- with _initialized_session():
- with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
- net.eval()
+ if ops._USE_C_API:
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ net = fc.input_layer(features, [price])
+ else:
+ net = fc.input_layer(features, [price])
+ with _initialized_session():
+ with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
+ net.eval()
def test_reshaping(self):
price = fc.numeric_column('price', shape=[1, 2])
@@ -4137,7 +4150,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc_lib._shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
@@ -4183,7 +4196,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
combiner='my_combiner',
@@ -4236,7 +4249,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- original_a, _ = fc_lib._shared_embedding_columns(
+ original_a, _ = fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
combiner='my_combiner',
@@ -4274,7 +4287,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
- fc_lib._shared_embedding_columns(
+ fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b], dimension=2,
initializer='not_fn')
@@ -4289,7 +4302,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
ValueError,
'all categorical_columns must have the same type.*'
'_IdentityCategoricalColumn.*_HashedCategoricalColumn'):
- fc_lib._shared_embedding_columns(
+ fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b, categorical_column_c],
dimension=2)
@@ -4302,11 +4315,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='bbb', num_buckets=3)
weighted_categorical_column_b = fc.weighted_categorical_column(
categorical_column_b, weight_feature_key='bbb_weights')
- fc_lib._shared_embedding_columns(
+ fc.shared_embedding_columns(
[weighted_categorical_column_a, categorical_column_b], dimension=2)
- fc_lib._shared_embedding_columns(
+ fc.shared_embedding_columns(
[categorical_column_a, weighted_categorical_column_b], dimension=2)
- fc_lib._shared_embedding_columns(
+ fc.shared_embedding_columns(
[weighted_categorical_column_a, weighted_categorical_column_b],
dimension=2)
@@ -4315,7 +4328,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
b = fc.categorical_column_with_vocabulary_list(
key='bbb', vocabulary_list=('omar', 'stringer', 'marlo'))
- a_embedded, b_embedded = fc_lib._shared_embedding_columns(
+ a_embedded, b_embedded = fc.shared_embedding_columns(
[a, b], dimension=2)
data = example_pb2.Example(features=feature_pb2.Features(
feature={
@@ -4350,7 +4363,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
def test_transform_feature(self):
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
- a_embedded, b_embedded = fc_lib._shared_embedding_columns(
+ a_embedded, b_embedded = fc.shared_embedding_columns(
[a, b], dimension=2)
features = {
'aaa': sparse_tensor.SparseTensor(
@@ -4420,7 +4433,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension, initializer=_initializer)
@@ -4482,7 +4495,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension, initializer=_initializer)
@@ -4522,7 +4535,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension, initializer=_initializer)
@@ -4628,7 +4641,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension, initializer=_initializer,
trainable=trainable)
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index ac915157f5..d3d8c9c154 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -52,6 +52,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+from tensorflow.python.util.tf_export import tf_export
def _eager_reshape(tensor, shape, ctx):
@@ -59,7 +60,6 @@ def _eager_reshape(tensor, shape, ctx):
attr_t = tensor._datatype_enum() # pylint: disable=protected-access
attr_tshape, (shape,) = execute.args_to_matching_eager(
[shape], ctx, dtypes.int32)
- attr_tshape = attr_tshape
inputs_flat = [tensor, shape]
attrs = ("T", attr_t, "Tshape", attr_tshape)
result, = execute.execute(
@@ -131,6 +131,7 @@ def convert_to_eager_tensor(value, ctx, dtype=None):
return ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
+@tf_export("constant")
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""Creates a constant tensor.
diff --git a/tensorflow/python/framework/cpp_shape_inference.h b/tensorflow/python/framework/cpp_shape_inference.h
index afca7277c7..c6ab6b106f 100644
--- a/tensorflow/python/framework/cpp_shape_inference.h
+++ b/tensorflow/python/framework/cpp_shape_inference.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
-#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
+#ifndef TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
+#define TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
@@ -51,4 +51,4 @@ std::vector<string> RunCppShapeInference(
} // namespace swig
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
+#endif // TENSORFLOW_PYTHON_FRAMEWORK_CPP_SHAPE_INFERENCE_H_
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index 8f5125dcfe..ab06a2babf 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function
import copy
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("DeviceSpec")
class DeviceSpec(object):
"""Represents a (possibly partial) specification for a TensorFlow device.
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index b0422eb6be..67ccf990d6 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -23,11 +23,13 @@ import numpy as np
from tensorflow.core.framework import types_pb2
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.util.tf_export import tf_export
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
+@tf_export("DType")
class DType(object):
"""Represents the type of the elements in a `Tensor`.
@@ -321,32 +323,55 @@ dtype_range = {np.bool_: (False, True),
# Define standard wrappers for the types_pb2.DataType enum.
resource = DType(types_pb2.DT_RESOURCE)
+tf_export("resource").export_constant(__name__, "resource")
variant = DType(types_pb2.DT_VARIANT)
+tf_export("variant").export_constant(__name__, "variant")
float16 = DType(types_pb2.DT_HALF)
+tf_export("float16").export_constant(__name__, "float16")
half = float16
+tf_export("half").export_constant(__name__, "half")
float32 = DType(types_pb2.DT_FLOAT)
+tf_export("float32").export_constant(__name__, "float32")
float64 = DType(types_pb2.DT_DOUBLE)
+tf_export("float64").export_constant(__name__, "float64")
double = float64
+tf_export("double").export_constant(__name__, "double")
int32 = DType(types_pb2.DT_INT32)
+tf_export("int32").export_constant(__name__, "int32")
uint8 = DType(types_pb2.DT_UINT8)
+tf_export("uint8").export_constant(__name__, "uint8")
uint16 = DType(types_pb2.DT_UINT16)
+tf_export("uint16").export_constant(__name__, "uint16")
uint32 = DType(types_pb2.DT_UINT32)
uint64 = DType(types_pb2.DT_UINT64)
int16 = DType(types_pb2.DT_INT16)
+tf_export("int16").export_constant(__name__, "int16")
int8 = DType(types_pb2.DT_INT8)
+tf_export("int8").export_constant(__name__, "int8")
string = DType(types_pb2.DT_STRING)
+tf_export("string").export_constant(__name__, "string")
complex64 = DType(types_pb2.DT_COMPLEX64)
+tf_export("complex64").export_constant(__name__, "complex64")
complex128 = DType(types_pb2.DT_COMPLEX128)
+tf_export("complex128").export_constant(__name__, "complex128")
int64 = DType(types_pb2.DT_INT64)
+tf_export("int64").export_constant(__name__, "int64")
bool = DType(types_pb2.DT_BOOL)
+tf_export("bool").export_constant(__name__, "bool")
qint8 = DType(types_pb2.DT_QINT8)
+tf_export("qint8").export_constant(__name__, "qint8")
quint8 = DType(types_pb2.DT_QUINT8)
+tf_export("quint8").export_constant(__name__, "quint8")
qint16 = DType(types_pb2.DT_QINT16)
+tf_export("qint16").export_constant(__name__, "qint16")
quint16 = DType(types_pb2.DT_QUINT16)
+tf_export("quint16").export_constant(__name__, "quint16")
qint32 = DType(types_pb2.DT_QINT32)
+tf_export("qint32").export_constant(__name__, "qint32")
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
bfloat16 = DType(types_pb2.DT_BFLOAT16)
+tf_export("bfloat16").export_constant(__name__, "bfloat16")
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
@@ -578,8 +603,10 @@ _TF_TO_NP = {
QUANTIZED_DTYPES = frozenset(
[qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref,
quint16_ref, qint32_ref])
+tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
+@tf_export("as_dtype")
def as_dtype(type_value):
"""Converts the given `type_value` to a `DType`.
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index c3b2c498c3..2a40316d51 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -25,8 +25,10 @@ from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("OpError", "errors.OpError")
class OpError(Exception):
"""A generic error that is raised when TensorFlow execution fails.
@@ -133,25 +135,48 @@ class OpError(Exception):
OK = error_codes_pb2.OK
+tf_export("errors.OK").export_constant(__name__, "OK")
CANCELLED = error_codes_pb2.CANCELLED
+tf_export("errors.CANCELLED").export_constant(__name__, "CANCELLED")
UNKNOWN = error_codes_pb2.UNKNOWN
+tf_export("errors.UNKNOWN").export_constant(__name__, "UNKNOWN")
INVALID_ARGUMENT = error_codes_pb2.INVALID_ARGUMENT
+tf_export("errors.INVALID_ARGUMENT").export_constant(__name__,
+ "INVALID_ARGUMENT")
DEADLINE_EXCEEDED = error_codes_pb2.DEADLINE_EXCEEDED
+tf_export("errors.DEADLINE_EXCEEDED").export_constant(__name__,
+ "DEADLINE_EXCEEDED")
NOT_FOUND = error_codes_pb2.NOT_FOUND
+tf_export("errors.NOT_FOUND").export_constant(__name__, "NOT_FOUND")
ALREADY_EXISTS = error_codes_pb2.ALREADY_EXISTS
+tf_export("errors.ALREADY_EXISTS").export_constant(__name__, "ALREADY_EXISTS")
PERMISSION_DENIED = error_codes_pb2.PERMISSION_DENIED
+tf_export("errors.PERMISSION_DENIED").export_constant(__name__,
+ "PERMISSION_DENIED")
UNAUTHENTICATED = error_codes_pb2.UNAUTHENTICATED
+tf_export("errors.UNAUTHENTICATED").export_constant(__name__, "UNAUTHENTICATED")
RESOURCE_EXHAUSTED = error_codes_pb2.RESOURCE_EXHAUSTED
+tf_export("errors.RESOURCE_EXHAUSTED").export_constant(__name__,
+ "RESOURCE_EXHAUSTED")
FAILED_PRECONDITION = error_codes_pb2.FAILED_PRECONDITION
+tf_export("errors.FAILED_PRECONDITION").export_constant(__name__,
+ "FAILED_PRECONDITION")
ABORTED = error_codes_pb2.ABORTED
+tf_export("errors.ABORTED").export_constant(__name__, "ABORTED")
OUT_OF_RANGE = error_codes_pb2.OUT_OF_RANGE
+tf_export("errors.OUT_OF_RANGE").export_constant(__name__, "OUT_OF_RANGE")
UNIMPLEMENTED = error_codes_pb2.UNIMPLEMENTED
+tf_export("errors.UNIMPLEMENTED").export_constant(__name__, "UNIMPLEMENTED")
INTERNAL = error_codes_pb2.INTERNAL
+tf_export("errors.INTERNAL").export_constant(__name__, "INTERNAL")
UNAVAILABLE = error_codes_pb2.UNAVAILABLE
+tf_export("errors.UNAVAILABLE").export_constant(__name__, "UNAVAILABLE")
DATA_LOSS = error_codes_pb2.DATA_LOSS
+tf_export("errors.DATA_LOSS").export_constant(__name__, "DATA_LOSS")
# pylint: disable=line-too-long
+@tf_export("errors.CancelledError")
class CancelledError(OpError):
"""Raised when an operation or step is cancelled.
@@ -172,6 +197,7 @@ class CancelledError(OpError):
# pylint: enable=line-too-long
+@tf_export("errors.UnknownError")
class UnknownError(OpError):
"""Unknown error.
@@ -189,6 +215,7 @@ class UnknownError(OpError):
super(UnknownError, self).__init__(node_def, op, message, error_code)
+@tf_export("errors.InvalidArgumentError")
class InvalidArgumentError(OpError):
"""Raised when an operation receives an invalid argument.
@@ -209,6 +236,7 @@ class InvalidArgumentError(OpError):
INVALID_ARGUMENT)
+@tf_export("errors.DeadlineExceededError")
class DeadlineExceededError(OpError):
"""Raised when a deadline expires before an operation could complete.
@@ -223,6 +251,7 @@ class DeadlineExceededError(OpError):
DEADLINE_EXCEEDED)
+@tf_export("errors.NotFoundError")
class NotFoundError(OpError):
"""Raised when a requested entity (e.g., a file or directory) was not found.
@@ -239,6 +268,7 @@ class NotFoundError(OpError):
super(NotFoundError, self).__init__(node_def, op, message, NOT_FOUND)
+@tf_export("errors.AlreadyExistsError")
class AlreadyExistsError(OpError):
"""Raised when an entity that we attempted to create already exists.
@@ -256,6 +286,7 @@ class AlreadyExistsError(OpError):
ALREADY_EXISTS)
+@tf_export("errors.PermissionDeniedError")
class PermissionDeniedError(OpError):
"""Raised when the caller does not have permission to run an operation.
@@ -273,6 +304,7 @@ class PermissionDeniedError(OpError):
PERMISSION_DENIED)
+@tf_export("errors.UnauthenticatedError")
class UnauthenticatedError(OpError):
"""The request does not have valid authentication credentials.
@@ -287,6 +319,7 @@ class UnauthenticatedError(OpError):
UNAUTHENTICATED)
+@tf_export("errors.ResourceExhaustedError")
class ResourceExhaustedError(OpError):
"""Some resource has been exhausted.
@@ -302,6 +335,7 @@ class ResourceExhaustedError(OpError):
RESOURCE_EXHAUSTED)
+@tf_export("errors.FailedPreconditionError")
class FailedPreconditionError(OpError):
"""Operation was rejected because the system is not in a state to execute it.
@@ -318,6 +352,7 @@ class FailedPreconditionError(OpError):
FAILED_PRECONDITION)
+@tf_export("errors.AbortedError")
class AbortedError(OpError):
"""The operation was aborted, typically due to a concurrent action.
@@ -335,6 +370,7 @@ class AbortedError(OpError):
super(AbortedError, self).__init__(node_def, op, message, ABORTED)
+@tf_export("errors.OutOfRangeError")
class OutOfRangeError(OpError):
"""Raised when an operation iterates past the valid input range.
@@ -353,6 +389,7 @@ class OutOfRangeError(OpError):
OUT_OF_RANGE)
+@tf_export("errors.UnimplementedError")
class UnimplementedError(OpError):
"""Raised when an operation has not been implemented.
@@ -371,6 +408,7 @@ class UnimplementedError(OpError):
UNIMPLEMENTED)
+@tf_export("errors.InternalError")
class InternalError(OpError):
"""Raised when the system experiences an internal error.
@@ -385,6 +423,7 @@ class InternalError(OpError):
super(InternalError, self).__init__(node_def, op, message, INTERNAL)
+@tf_export("errors.UnavailableError")
class UnavailableError(OpError):
"""Raised when the runtime is currently unavailable.
@@ -399,6 +438,7 @@ class UnavailableError(OpError):
UNAVAILABLE)
+@tf_export("errors.DataLossError")
class DataLossError(OpError):
"""Raised when unrecoverable data loss or corruption is encountered.
@@ -437,10 +477,12 @@ _EXCEPTION_CLASS_TO_CODE = dict((
(class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items()))
+@tf_export("errors.exception_type_from_error_code")
def exception_type_from_error_code(error_code):
return _CODE_TO_EXCEPTION_CLASS[error_code]
+@tf_export("errors.error_code_from_exception_type")
def error_code_from_exception_type(cls):
return _EXCEPTION_CLASS_TO_CODE[cls]
@@ -457,7 +499,8 @@ def _make_specific_exception(node_def, op, message, error_code):
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
-class raise_exception_on_not_ok_status(object): # pylint: disable=invalid-name
+@tf_export("errors.raise_exception_on_not_ok_status") # pylint: disable=invalid-name
+class raise_exception_on_not_ok_status(object):
"""Context manager to check for C API status."""
def __enter__(self):
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index e06899f81d..cba225e749 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -417,7 +417,7 @@ class _DefinedFunction(object):
if self._func_name:
assert self._func_name == self._op_def.name
else:
- self._func_name = self._op_def.name
+ self._func_name = compat.as_str(self._op_def.name)
def _set_c_attrs(self, attrs):
"""Sets `attrs` as attributes of self._c_func.
@@ -682,7 +682,7 @@ class _FuncGraph(ops.Graph):
def create_op(self, op_type, inputs, data_types, **kwargs):
for i, x in enumerate(inputs):
- if x.graph is not self:
+ if isinstance(x, ops.EagerTensor) or x.graph is not self:
# Referring to a tensor from other graph.
if x in self._captured:
# Captured already.
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 57e5a724c9..a4ca3f9a89 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -451,13 +452,17 @@ class FunctionTest(test.TestCase):
lambda y: AssertFail(y), [x])
# pylint: enable=unnecessary-lambda
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
# Enables inlining.
- config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
- optimizer_options=config_pb2.OptimizerOptions(
- opt_level=config_pb2.OptimizerOptions.L0,
- do_common_subexpression_elimination=True,
- do_function_inlining=True,
- do_constant_folding=True)))
+ config = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L0,
+ do_common_subexpression_elimination=True,
+ do_function_inlining=True,
+ do_constant_folding=True),
+ rewrite_options=rewriter_config))
with session.Session(config=config) as sess:
# Since the 'False' branch is not taken, the assertion should not fire.
diff --git a/tensorflow/python/framework/graph_io.py b/tensorflow/python/framework/graph_io.py
index a0ea4ad48e..be30b16f5f 100644
--- a/tensorflow/python/framework/graph_io.py
+++ b/tensorflow/python/framework/graph_io.py
@@ -24,8 +24,10 @@ import os.path
from google.protobuf import text_format
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('train.write_graph')
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
"""Writes a graph proto to a file.
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index 6c7b455388..5a543317e6 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
_VARIABLE_OPS = {
"Assign",
@@ -49,6 +50,7 @@ def _is_variable_op(op):
return op in _VARIABLE_OPS
+@tf_export("graph_util.must_run_on_cpu")
def must_run_on_cpu(node, pin_variables_on_cpu=False):
"""Returns True if the given node_def must run on CPU, otherwise False.
@@ -147,6 +149,7 @@ def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
return nodes_to_keep
+@tf_export("graph_util.extract_sub_graph")
def extract_sub_graph(graph_def, dest_nodes):
"""Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
@@ -184,6 +187,7 @@ def extract_sub_graph(graph_def, dest_nodes):
return out
+@tf_export("graph_util.tensor_shape_from_node_def_name")
def tensor_shape_from_node_def_name(graph, input_name):
"""Convenience function to get a shape from a NodeDef's input string."""
# To get a tensor, the name must be in the form <input>:<port>, for example
@@ -198,6 +202,7 @@ def tensor_shape_from_node_def_name(graph, input_name):
return shape
+@tf_export("graph_util.convert_variables_to_constants")
def convert_variables_to_constants(sess,
input_graph_def,
output_node_names,
@@ -270,6 +275,7 @@ def convert_variables_to_constants(sess,
return output_graph_def
+@tf_export("graph_util.remove_training_nodes")
def remove_training_nodes(input_graph, protected_nodes=None):
"""Prunes out nodes that aren't needed for inference.
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index a3dbe43f06..00fff8d040 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -36,6 +36,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.tf_export import tf_export
# TODO(josh11b): SWIG the code from node_def_util instead of duplicating
@@ -369,6 +370,7 @@ def _GatherReturnElements(requested_return_elements, graph, results):
return combined_return_elements
+@tf_export('import_graph_def')
@deprecated_args(None, 'Please file an issue at '
'https://github.com/tensorflow/tensorflow/issues if you depend'
' on this feature.',
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 909e6d4c7b..c997ead829 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -28,8 +28,10 @@ from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow as py_tf
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('load_op_library')
def load_op_library(library_filename):
"""Loads a TensorFlow plugin, containing custom ops and kernels.
@@ -79,6 +81,7 @@ def load_op_library(library_filename):
return module
+@tf_export('load_file_system_library')
def load_file_system_library(library_filename):
"""Loads a TensorFlow plugin, containing file system implementation.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index e7f08a64a6..b107670275 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -55,6 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util.tf_export import tf_export
# Temporary global switch determining if we should enable the work-in-progress
@@ -191,6 +192,7 @@ class _TensorLike(object):
pass
+@tf_export("Tensor")
class Tensor(_TensorLike):
"""Represents one of the outputs of an `Operation`.
@@ -285,7 +287,7 @@ class Tensor(_TensorLike):
self._op = op
self._value_index = value_index
self._dtype = dtypes.as_dtype(dtype)
- self._shape = tensor_shape.unknown_shape()
+ self._shape_val = tensor_shape.unknown_shape()
# List of operations that use this Tensor as input. We maintain this list
# to easily navigate a computation graph.
self._consumers = []
@@ -379,7 +381,18 @@ class Tensor(_TensorLike):
graph, self._as_tf_output(), num_dims, status)
dim_list = [None if i == -1 else i for i in dim_list]
return tensor_shape.TensorShape(dim_list)
- return self._shape
+ return self._shape_val
+
+ @property
+ def _shape(self):
+ logging.warning("Tensor._shape is private, use Tensor.shape "
+ "instead. Tensor._shape will eventually be removed.")
+ return self.shape
+
+ @_shape.setter
+ def _shape(self, value):
+ raise ValueError(
+ "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")
def __iter__(self):
if context.in_graph_mode():
@@ -454,7 +467,7 @@ class Tensor(_TensorLike):
this tensor.
"""
if not _USE_C_API:
- self._shape = self._shape.merge_with(shape) # pylint: disable=protected-access
+ self._shape_val = self._shape_val.merge_with(shape)
return
if not isinstance(shape, tensor_shape.TensorShape):
shape = tensor_shape.TensorShape(shape)
@@ -468,13 +481,17 @@ class Tensor(_TensorLike):
dim_list.append(-1)
else:
dim_list.append(dim.value)
- with errors.raise_exception_on_not_ok_status() as status:
- c_api.TF_GraphSetTensorShape_wrapper(
- self._op._graph._c_graph, # pylint: disable=protected-access
- self._as_tf_output(),
- dim_list,
- unknown_shape,
- status)
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_api.TF_GraphSetTensorShape_wrapper(
+ self._op._graph._c_graph, # pylint: disable=protected-access
+ self._as_tf_output(),
+ dim_list,
+ unknown_shape,
+ status)
+ except errors.InvalidArgumentError as e:
+ # Convert to ValueError for backwards compatibility.
+ raise ValueError(str(e))
@property
def value_index(self):
@@ -775,6 +792,11 @@ class _EagerTensorBase(Tensor):
"""The shape of the tensor as a list."""
return list(self._shape_tuple())
+ @property
+ def ndim(self):
+ """Returns the number of Tensor dimensions."""
+ return self.shape.ndims
+
def cpu(self):
"""A copy of this Tensor with contents backed by host memory."""
return self._copy(context.context(), "CPU:0")
@@ -866,6 +888,7 @@ _tensor_conversion_func_lock = threading.Lock()
register_dense_tensor_like_type(Tensor)
+@tf_export("convert_to_tensor")
def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
"""Converts the given `value` to a `Tensor`.
@@ -1111,6 +1134,7 @@ def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
as_ref=False)
+@tf_export("convert_to_tensor_or_indexed_slices")
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
"""Converts the given object to a `Tensor` or an `IndexedSlices`.
@@ -1241,6 +1265,7 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
# TODO(josh11b): Add ctx argument to conversion_func() signature.
+@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
conversion_func,
priority=100):
@@ -1301,6 +1326,7 @@ def register_tensor_conversion_function(base_type,
_tensor_conversion_func_cache = {}
+@tf_export("IndexedSlices")
class IndexedSlices(_TensorLike):
"""A sparse representation of a set of tensor slices at given indices.
@@ -1481,6 +1507,7 @@ def _create_c_op(graph, node_def, inputs, control_inputs):
return c_op
+@tf_export("Operation")
class Operation(object):
"""Represents a graph node that performs computation on tensors.
@@ -1560,7 +1587,6 @@ class Operation(object):
"Cannot create a tensor proto whose content is larger than 2GB.")
if not _VALID_OP_NAME_REGEX.match(node_def.name):
raise ValueError("'%s' is not a valid node name" % node_def.name)
- self._node_def = copy.deepcopy(node_def)
c_op = None
elif type(node_def).__name__ == "SwigPyObject":
assert inputs is None
@@ -1569,7 +1595,6 @@ class Operation(object):
assert input_types is None
assert original_op is None
assert op_def is None
- self._node_def = None
c_op = node_def
else:
raise TypeError("node_def needs to be a NodeDef: %s" % node_def)
@@ -1577,28 +1602,27 @@ class Operation(object):
if not isinstance(g, Graph):
raise TypeError("g needs to be a Graph: %s" % g)
self._graph = g
+
if inputs is None:
inputs = []
elif not isinstance(inputs, list):
raise TypeError("inputs needs to be a list of Tensors: %s" % inputs)
- self._inputs = list(inputs) # Defensive copy.
- for a in self._inputs:
+ for a in inputs:
if not isinstance(a, Tensor):
raise TypeError("input needs to be a Tensor: %s" % a)
if input_types is None:
- input_types = [i.dtype.base_dtype for i in self._inputs]
+ input_types = [i.dtype.base_dtype for i in inputs]
else:
if not all(
x.is_compatible_with(i.dtype)
- for i, x in zip(self._inputs, input_types)):
+ for i, x in zip(inputs, input_types)):
raise TypeError("In op '%s', input types (%s) are not compatible "
"with expected types (%s)" %
- (self.node_def.name, [i.dtype for i in self._inputs],
+ (self.node_def.name, [i.dtype for i in inputs],
input_types))
- self._input_types_val = input_types
# Build the list of control inputs.
- self._control_inputs = []
+ control_input_ops = []
if control_inputs:
for c in control_inputs:
control_op = None
@@ -1609,11 +1633,20 @@ class Operation(object):
else:
raise TypeError("Control input must be an Operation, "
"a Tensor, or IndexedSlices: %s" % c)
- self._control_inputs.append(control_op)
+ control_input_ops.append(control_op)
+
+ # Don't set private fields with C API enabled to catch users who need to
+ # switch to public API.
+ # TODO(skyewm): delete these fields once we remove _USE_C_API
+ if not self._graph._c_graph:
+ self._inputs_val = list(inputs) # Defensive copy.
+ self._input_types_val = input_types
+ self._control_inputs_val = control_input_ops
+ self._node_def_val = copy.deepcopy(node_def)
+ self._op_def_val = op_def
self._id_value = self._graph._next_id() # pylint: disable=protected-access
self._original_op = original_op
- self._op_def = op_def
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access
@@ -1629,15 +1662,15 @@ class Operation(object):
# Refactor so we don't have to do this here.
grouped_inputs = self._reconstruct_sequence_inputs(
op_def, inputs, node_def.attr)
- self._c_op = _create_c_op(self._graph, self._node_def, grouped_inputs,
- self._control_inputs)
+ self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
+ control_input_ops)
else:
self._c_op = None
# Mark that we consume the inputs. This is unnecessary and unsupported with
# the C API enabled, since the C API tracks the tensor consumers instead.
if not self._c_op:
- for input_tensor in self._inputs:
+ for input_tensor in self._inputs_val:
input_tensor._add_consumer(self) # pylint: disable=protected-access
# Initialize self._outputs.
@@ -1752,7 +1785,7 @@ class Operation(object):
if self._c_op:
return c_api.TF_OperationName(self._c_op)
else:
- return self._node_def.name
+ return self._node_def_val.name
@property
def _id(self):
@@ -1771,7 +1804,7 @@ class Operation(object):
if self._c_op:
return c_api.TF_OperationDevice(self._c_op)
else:
- return self._node_def.device
+ return self._node_def_val.device
@property
def _output_types(self):
@@ -1831,7 +1864,7 @@ class Operation(object):
self._c_op, # pylint: disable=protected-access
compat.as_str(_device_string(device)))
else:
- self._node_def.device = _device_string(device)
+ self._node_def_val.device = _device_string(device)
def _add_input(self, tensor, dtype=None):
"""Add a new input to this operation.
@@ -1859,7 +1892,7 @@ class Operation(object):
raise TypeError(
"Cannot convert a tensor of type %s to an input of type %s" %
(tensor.dtype.name, dtype.name))
- self._inputs.append(tensor)
+ self._inputs_val.append(tensor)
self._input_types_val.append(dtype)
tensor._add_consumer(self) # pylint: disable=protected-access
self._recompute_node_def()
@@ -1889,8 +1922,8 @@ class Operation(object):
self._tf_input(index),
status)
else:
- self._inputs[index].consumers().remove(self)
- self._inputs[index] = tensor
+ self._inputs_val[index].consumers().remove(self)
+ self._inputs_val[index] = tensor
self._input_types_val[index] = tensor.dtype
tensor._add_consumer(self) # pylint: disable=protected-access
self._recompute_node_def()
@@ -1916,7 +1949,7 @@ class Operation(object):
if not isinstance(op, Operation):
raise TypeError("op must be an Operation: %s" % op)
_assert_same_graph(self, op)
- self._control_inputs.append(op)
+ self._control_inputs_val.append(op)
self._recompute_node_def()
def _add_control_input(self, op):
@@ -1948,13 +1981,14 @@ class Operation(object):
# TODO(skyewm): remove this function when we switch to C API
if self._c_op: return
- del self._node_def.input[:]
+ del self._node_def_val.input[:]
# pylint: disable=protected-access
- self._node_def.input.extend([t._as_node_def_input() for t in self._inputs])
+ self._node_def_val.input.extend(
+ [t._as_node_def_input() for t in self._inputs_val])
# pylint: enable=protected-access
- if self._control_inputs:
- self._node_def.input.extend(
- ["^%s" % op.name for op in self._control_inputs])
+ if self._control_inputs_val:
+ self._node_def_val.input.extend(
+ ["^%s" % op.name for op in self._control_inputs_val])
def __str__(self):
return str(self.node_def)
@@ -2004,7 +2038,17 @@ class Operation(object):
]
# pylint: enable=protected-access
return Operation._InputList(retval)
- return Operation._InputList(self._inputs)
+ return Operation._InputList(self._inputs_val)
+
+ @property
+ def _inputs(self):
+ logging.warning("Operation._inputs is private, use Operation.inputs "
+ "instead. Operation._inputs will eventually be removed.")
+ return self.inputs
+
+ @_inputs.setter
+ def _inputs(self, value):
+ raise ValueError("Cannot assign _inputs")
@property
def _input_types(self):
@@ -2018,6 +2062,10 @@ class Operation(object):
else:
return self._input_types_val
+ @_input_types.setter
+ def _input_types(self, value):
+ raise ValueError("Cannot assign _input_types")
+
@property
def control_inputs(self):
"""The `Operation` objects on which this op has a control dependency.
@@ -2041,7 +2089,22 @@ class Operation(object):
]
# pylint: enable=protected-access
else:
- return self._control_inputs
+ return self._control_inputs_val
+
+ @property
+ def _control_inputs(self):
+ logging.warning("Operation._control_inputs is private, use "
+ "Operation.control_inputs instead. "
+ "Operation._control_inputs will eventually be removed.")
+ return self.control_inputs
+
+ @_control_inputs.setter
+ def _control_inputs(self, value):
+ logging.warning("Operation._control_inputs is private, use "
+ "Operation.control_inputs instead. "
+ "Operation._control_inputs will eventually be removed.")
+ self._remove_all_control_inputs()
+ self._add_control_inputs(value)
@property
def type(self):
@@ -2050,7 +2113,7 @@ class Operation(object):
op_type = c_api.TF_OperationOpType(self._c_op)
return op_type
else:
- return self._node_def.op
+ return self._node_def_val.op
@property
def graph(self):
@@ -2077,7 +2140,13 @@ class Operation(object):
node_def.ParseFromString(compat.as_bytes(data))
return node_def
else:
- return self._node_def
+ return self._node_def_val
+
+ @property
+ def _node_def(self):
+ logging.warning("Operation._node_def is private, use Operation.node_def "
+ "instead. Operation._node_def will eventually be removed.")
+ return self.node_def
@property
def op_def(self):
@@ -2102,7 +2171,13 @@ class Operation(object):
op_def.ParseFromString(compat.as_bytes(data))
return op_def
else:
- return self._op_def
+ return self._op_def_val
+
+ @property
+ def _op_def(self):
+ logging.warning("Operation._op_def is private, use Operation.op_def "
+ "instead. Operation._op_def will eventually be removed.")
+ return self.op_def
@property
def traceback(self):
@@ -2134,7 +2209,7 @@ class Operation(object):
finally:
c_api.TF_DeleteBuffer(buf)
else:
- self._node_def.attr[attr_name].CopyFrom(attr_value)
+ self._node_def_val.attr[attr_name].CopyFrom(attr_value)
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.
@@ -2161,10 +2236,10 @@ class Operation(object):
x = attr_value_pb2.AttrValue()
x.ParseFromString(data)
else:
- if name not in self._node_def.attr:
+ if name not in self._node_def_val.attr:
raise ValueError(
- "No attr named '" + name + "' in " + str(self._node_def))
- x = self._node_def.attr[name]
+ "No attr named '" + name + "' in " + str(self._node_def_val))
+ x = self._node_def_val.attr[name]
# Treat an empty oneof value as an empty list.
if not x.WhichOneof("value"):
@@ -2208,6 +2283,7 @@ class Operation(object):
_gradient_registry = registry.Registry("gradient")
+@tf_export("RegisterGradient")
class RegisterGradient(object):
"""A decorator for registering the gradient function for an op type.
@@ -2250,6 +2326,7 @@ class RegisterGradient(object):
return f
+@tf_export("NoGradient", "NotDifferentiable")
def NotDifferentiable(op_type):
"""Specifies that ops of type `op_type` is not differentiable.
@@ -2569,6 +2646,7 @@ def _name_from_scope_name(name):
return name[:-1] if (name and name[-1] == "/") else name
+@tf_export("Graph")
class Graph(object):
"""A TensorFlow computation, represented as a dataflow graph.
@@ -2695,6 +2773,7 @@ class Graph(object):
self._scoped_c_graph = c_api_util.ScopedTFGraph()
else:
self._scoped_c_graph = None
+ self._variable_creator_stack = []
# TODO(apassos) remove once the C API is used by default.
def _use_c_api_hack(self):
@@ -2731,6 +2810,22 @@ class Graph(object):
ret.append((filename, lineno, name, line))
return ret
+ # Note: this method is private because the API of tf.Graph() is public and
+ # frozen, and this functionality is still not ready for public visibility.
+ @tf_contextlib.contextmanager
+ def _variable_creator_scope(self, creator):
+ old = list(self._variable_creator_stack)
+ self._variable_creator_stack.append(creator)
+ try:
+ yield
+ finally:
+ self._variable_creator_stack = old
+
+ # Note: this method is private because the API of tf.Graph() is public and
+ # frozen, and this functionality is still not ready for public visibility.
+ def _get_variable_creator_stack(self):
+ return list(self._variable_creator_stack)
+
def _extract_stack(self):
"""A lightweight, extensible re-implementation of traceback.extract_stack.
@@ -4164,10 +4259,10 @@ class Graph(object):
"""
self._graph = graph
if control_inputs is None:
- self._control_inputs = []
+ self._control_inputs_val = []
self._new_stack = True
else:
- self._control_inputs = control_inputs
+ self._control_inputs_val = control_inputs
self._new_stack = False
self._seen_nodes = set()
self._old_stack = None
@@ -4195,7 +4290,7 @@ class Graph(object):
@property
def control_inputs(self):
- return self._control_inputs
+ return self._control_inputs_val
def add_op(self, op):
self._seen_nodes.add(op)
@@ -4569,6 +4664,9 @@ class Graph(object):
# TODO(agarwal): currently device directives in an outer eager scope will not
# apply to inner graph mode code. Fix that.
+
+
+@tf_export("device")
def device(device_name_or_function):
"""Wrapper for `Graph.device()` using the default graph.
@@ -4598,6 +4696,7 @@ def device(device_name_or_function):
return context.device(device_name_or_function)
+@tf_export("container")
def container(container_name):
"""Wrapper for `Graph.container()` using the default graph.
@@ -4611,6 +4710,7 @@ def container(container_name):
return get_default_graph().container(container_name)
+@tf_export("colocate_with")
def colocate_with(op, ignore_existing=False):
if context.in_graph_mode():
return get_default_graph().colocate_with(op, ignore_existing)
@@ -4621,6 +4721,7 @@ def colocate_with(op, ignore_existing=False):
return _NullContextmanager()
+@tf_export("control_dependencies")
def control_dependencies(control_inputs):
"""Wrapper for `Graph.control_dependencies()` using the default graph.
@@ -4738,6 +4839,7 @@ def default_session(session):
return _default_session_stack.get_controller(session)
+@tf_export("get_default_session")
def get_default_session():
"""Returns the default session for the current thread.
@@ -4950,6 +5052,8 @@ def enable_eager_execution(config=None, device_policy=None):
right device but raises a warning.
tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
hide performance problems.
+ tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
+ raising errors on the other ones.
Raises:
ValueError: If trying to create a context after using graph operations
@@ -4961,10 +5065,10 @@ def enable_eager_execution(config=None, device_policy=None):
"config must be a tf.ConfigProto, but got %s" % type(config))
if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
context.DEVICE_PLACEMENT_WARN,
- context.DEVICE_PLACEMENT_SILENT):
+ context.DEVICE_PLACEMENT_SILENT,
+ context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
raise ValueError(
- "device_policy must be one of None, tfe.DEVICE_PLACEMENT_EXPLICIT, "
- "tfe.DEVICE_PLACEMENT_WARN, tfe.DEVICE_PLACEMENT_SILENT"
+ "device_policy must be one of None, tfe.DEVICE_PLACEMENT_*"
)
# pylint: disable=protected-access
if context._default_mode == context.GRAPH_MODE:
@@ -5027,6 +5131,7 @@ def eager_run(main=None, argv=None):
app.run(main, argv)
+@tf_export("reset_default_graph")
def reset_default_graph():
"""Clears the default graph stack and resets the global default graph.
@@ -5045,6 +5150,7 @@ def reset_default_graph():
_default_graph_stack.reset()
+@tf_export("get_default_graph")
def get_default_graph():
"""Returns the default graph for the current thread.
@@ -5165,6 +5271,7 @@ def _get_graph_from_inputs(op_input_list, graph=None):
return graph or get_default_graph()
+@tf_export("GraphKeys")
class GraphKeys(object):
"""Standard names to use for graph collections.
@@ -5313,6 +5420,7 @@ class GraphKeys(object):
return cls.GLOBAL_VARIABLES
+@tf_export("add_to_collection")
def add_to_collection(name, value):
"""Wrapper for `Graph.add_to_collection()` using the default graph.
@@ -5349,6 +5457,7 @@ def add_to_collections(names, value):
get_default_graph().add_to_collections(names, value)
+@tf_export("get_collection_ref")
def get_collection_ref(key):
"""Wrapper for `Graph.get_collection_ref()` using the default graph.
@@ -5372,6 +5481,7 @@ def get_collection_ref(key):
return get_default_graph().get_collection_ref(key)
+@tf_export("get_collection")
def get_collection(key, scope=None):
"""Wrapper for `Graph.get_collection()` using the default graph.
@@ -5408,6 +5518,7 @@ def get_all_collection_keys():
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
+@tf_export("name_scope", "keras.backend.name_scope")
class name_scope(object): # pylint: disable=invalid-name
"""A context manager for use when defining a Python op.
@@ -5554,6 +5665,7 @@ def prepend_name_scope(name, import_scope):
# pylint: disable=g-doc-return-or-yield
# pylint: disable=not-context-manager
+@tf_export("op_scope")
@tf_contextlib.contextmanager
def op_scope(values, name, default_name=None):
"""DEPRECATED. Same as name_scope above, just different argument order."""
diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h
index 6b53825a6d..d09b36a3e8 100644
--- a/tensorflow/python/framework/python_op_gen_internal.h
+++ b/tensorflow/python/framework/python_op_gen_internal.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
-#define THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
+#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
+#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
#include <unordered_map>
@@ -112,4 +112,4 @@ class GenPythonOp {
} // namespace python_op_gen_internal
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
+#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_OP_GEN_INTERNAL_H_
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index 5f1130570d..1e74a790a3 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
+from tensorflow.python.util.tf_export import tf_export
DEFAULT_GRAPH_SEED = 87654321
@@ -32,6 +33,7 @@ def _truncate_seed(seed):
return seed % _MAXINT32 # Truncate to fit into 32-bit integer
+@tf_export('get_seed')
def get_seed(op_seed):
"""Returns the local seeds an operation should use given an op-specific seed.
@@ -78,6 +80,7 @@ def get_seed(op_seed):
return seeds
+@tf_export('set_random_seed')
def set_random_seed(seed):
"""Sets the graph-level random seed.
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index 6218cc34ca..1fe81e5f17 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -23,6 +23,7 @@ import collections
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
+from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
_TensorLike = ops._TensorLike
@@ -31,6 +32,7 @@ _override_helper = ops._override_helper
# pylint: enable=protected-access
+@tf_export("SparseTensor")
class SparseTensor(_TensorLike):
"""Represents a sparse tensor.
@@ -222,8 +224,10 @@ class SparseTensor(_TensorLike):
SparseTensorValue = collections.namedtuple(
"SparseTensorValue", ["indices", "values", "dense_shape"])
+tf_export("SparseTensorValue")(SparseTensorValue)
+@tf_export("convert_to_tensor_or_sparse_tensor")
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
"""Converts value to a `SparseTensor` or `Tensor`.
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 54ec15ea66..222071cb9e 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -19,8 +19,10 @@ from __future__ import print_function
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("Dimension")
class Dimension(object):
"""Represents the value of one dimension in a TensorShape."""
@@ -397,6 +399,7 @@ def as_dimension(value):
return Dimension(value)
+@tf_export("TensorShape")
class TensorShape(object):
"""Represents the shape of a `Tensor`.
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 1b90c7ad4d..d2b8e80305 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -38,6 +38,7 @@ except ImportError:
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.util.tf_export import tf_export
# pylint: enable=g-import-not-at-top
@@ -328,6 +329,7 @@ def _AssertCompatible(values, dtype):
(dtype.name, repr(mismatch), type(mismatch).__name__))
+@tf_export("make_tensor_proto")
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
"""Create a TensorProto.
@@ -515,6 +517,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
return tensor_proto
+@tf_export("make_ndarray")
def MakeNdarray(tensor):
"""Create a numpy ndarray from a tensor.
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 729c939870..6a7e1d0c89 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -53,6 +53,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import versions
@@ -65,8 +66,10 @@ from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.protobuf import compare
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("test.gpu_device_name")
def gpu_device_name():
"""Returns the name of a GPU device if available or the empty string."""
for x in device_lib.list_local_devices():
@@ -101,6 +104,7 @@ def assert_ops_in_graph(expected_ops, graph):
return actual_ops
+@tf_export("test.assert_equal_graph_def")
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
"""Asserts that two `GraphDef`s are (mostly) the same.
@@ -630,6 +634,7 @@ def run_in_graph_and_eager_modes(
return decorator
+@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
"""Returns whether TensorFlow can access a GPU.
@@ -678,6 +683,7 @@ def device(use_gpu):
yield
+@tf_export("test.TestCase")
class TensorFlowTestCase(googletest.TestCase):
"""Base class for tests that need to test TensorFlow.
"""
@@ -1125,43 +1131,90 @@ class TensorFlowTestCase(googletest.TestCase):
print("not close dif = ", np.abs(x - y))
print("not close tol = ", atol + rtol * np.abs(y))
print("dtype = %s, shape = %s" % (a.dtype, a.shape))
- np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, err_msg=msg)
+ # TODO(xpan): There seems to be a bug:
+ # tensorflow/compiler/tests:binary_ops_test pass with float32
+ # nan even though the equal_nan is False by default internally.
+ np.testing.assert_allclose(
+ a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True)
- def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
- """Asserts that two numpy arrays, or dicts of same, have near values.
-
- This does not support nested dicts. `a` and `b` can be namedtuples too,
- which are converted to dicts.
-
- Args:
- a: The expected numpy ndarray (or anything can be converted to one), or
- dict of same. Must be a dict iff `b` is a dict.
- b: The actual numpy ndarray (or anything can be converted to one), or
- dict of same. Must be a dict iff `a` is a dict.
- rtol: relative tolerance.
- atol: absolute tolerance.
+ def _assertAllCloseRecursive(self, a, b, rtol=1e-6, atol=1e-6, path=None):
+ path = path or []
+ path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "")
- Raises:
- ValueError: if only one of `a` and `b` is a dict.
- """
# Check if a and/or b are namedtuples.
if hasattr(a, "_asdict"):
a = a._asdict()
if hasattr(b, "_asdict"):
b = b._asdict()
- is_a_dict = isinstance(a, dict)
- if is_a_dict != isinstance(b, dict):
- raise ValueError("Can't compare dict to non-dict, %s vs %s." % (a, b))
- if is_a_dict:
+ a_is_dict = isinstance(a, dict)
+ if a_is_dict != isinstance(b, dict):
+ raise ValueError("Can't compare dict to non-dict, a%s vs b%s." %
+ (path_str, path_str))
+ if a_is_dict:
self.assertItemsEqual(
- a.keys(), b.keys(),
- msg="mismatched keys, expected %s, got %s" % (a.keys(), b.keys()))
+ a.keys(),
+ b.keys(),
+ msg="mismatched keys: a%s has keys %s, but b%s has keys %s" %
+ (path_str, a.keys(), path_str, b.keys()))
for k in a:
+ path.append(k)
+ self._assertAllCloseRecursive(
+ a[k], b[k], rtol=rtol, atol=atol, path=path)
+ del path[-1]
+ elif isinstance(a, (list, tuple)):
+ # Try to directly compare a, b as ndarrays; if not work, then traverse
+ # through the sequence, which is more expensive.
+ try:
+ a_as_ndarray = np.array(a)
+ b_as_ndarray = np.array(b)
self._assertArrayLikeAllClose(
- a[k], b[k], rtol=rtol, atol=atol,
- msg="%s: expected %s, got %s." % (k, a, b))
+ a_as_ndarray,
+ b_as_ndarray,
+ rtol=rtol,
+ atol=atol,
+ msg="Mismatched value: a%s is different from b%s." % (path_str,
+ path_str))
+ except (ValueError, TypeError) as e:
+ if len(a) != len(b):
+ raise ValueError(
+ "Mismatched length: a%s has %d items, but b%s has %d items" %
+ (path_str, len(a), path_str, len(b)))
+ for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
+ path.append(str(idx))
+ self._assertAllCloseRecursive(
+ a_ele, b_ele, rtol=rtol, atol=atol, path=path)
+ del path[-1]
+ # a and b are ndarray like objects
else:
- self._assertArrayLikeAllClose(a, b, rtol=rtol, atol=atol)
+ self._assertArrayLikeAllClose(
+ a,
+ b,
+ rtol=rtol,
+ atol=atol,
+ msg="Mismatched value: a%s is different from b%s." % (path_str,
+ path_str))
+
+ def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6):
+ """Asserts that two structures of numpy arrays, have near values.
+
+ `a` and `b` can be arbitrarily nested structures. A layer of a nested
+ structure can be a `dict`, `namedtuple`, `tuple` or `list`.
+
+ Args:
+ a: The expected numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray`, or any arbitrarily nested of structure of these.
+ b: The actual numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray`, or any arbitrarily nested of structure of these.
+ rtol: relative tolerance.
+ atol: absolute tolerance.
+
+ Raises:
+ ValueError: if only one of `a[p]` and `b[p]` is a dict or
+ `a[p]` and `b[p]` have different length, where `[p]` denotes a path
+ to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
+ `[p] = [1]['d']`, then `a[p] = (6, 7)`.
+ """
+ self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol)
def assertAllCloseAccordingToType(self,
a,
@@ -1326,6 +1379,7 @@ class TensorFlowTestCase(googletest.TestCase):
# pylint: enable=invalid-name
+@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers, num_ps, protocol="grpc",
worker_config=None, ps_config=None):
"""Create and start local servers and return the associated `Server` objects.
@@ -1407,3 +1461,14 @@ def get_node_def_from_graph(node_name, graph_def):
if node_def.name == node_name:
return node_def
return None
+
+
+def set_producer_version(graph, producer_version):
+ """Sets graph.graph_def_versions.producer to `producer_version`."""
+ # The C API doesn't expose altering GraphDefVersions. We can indirectly set
+ # it via import_graph_def though.
+ graph_def = graph_pb2.GraphDef()
+ graph_def.versions.producer = producer_version
+ with graph.as_default():
+ importer.import_graph_def(graph_def)
+ assert graph.graph_def_versions.producer, producer_version
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 6ddb3533e5..3594d125bf 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import copy
import random
import threading
@@ -252,12 +253,30 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})
- def testAllCloseNestedDicts(self):
- a = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
- with self.assertRaisesRegexp(
- TypeError,
- r"inputs could not be safely coerced to any supported types"):
- self.assertAllClose(a, a)
+ def testAllCloseListOfNamedtuples(self):
+ my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"])
+ l1 = [
+ my_named_tuple(x=np.array([[2.3, 2.5]]), y=np.array([[0.97, 0.96]])),
+ my_named_tuple(x=np.array([[3.3, 3.5]]), y=np.array([[0.98, 0.99]]))
+ ]
+ l2 = [
+ ([[2.3, 2.5]], [[0.97, 0.96]]),
+ ([[3.3, 3.5]], [[0.98, 0.99]]),
+ ]
+ self.assertAllClose(l1, l2)
+
+ def testAllCloseNestedStructure(self):
+ a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])}
+ self.assertAllClose(a, a)
+
+ b = copy.deepcopy(a)
+ self.assertAllClose(a, b)
+
+ # Test mismatched values
+ b["y"][1][0]["nested"]["n"] = 4.2
+ with self.assertRaisesRegexp(AssertionError,
+ r"\[y\]\[1\]\[0\]\[nested\]\[n\]"):
+ self.assertAllClose(a, b)
def testArrayNear(self):
a = [1, 2]
@@ -282,6 +301,9 @@ class TestUtilTest(test_util.TensorFlowTestCase):
control_flow_ops.Assert(x, y).run()
def testAssertAllCloseAccordingToType(self):
+ # test plain int
+ self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8)
+
# test float64
self.assertAllCloseAccordingToType(
np.asarray([1e-8], dtype=np.float64),
diff --git a/tensorflow/python/framework/versions.py b/tensorflow/python/framework/versions.py
index f03b81eb28..bdcbc15af6 100644
--- a/tensorflow/python/framework/versions.py
+++ b/tensorflow/python/framework/versions.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.util.tf_export import tf_export
__version__ = pywrap_tensorflow.__version__
__git_version__ = pywrap_tensorflow.__git_version__
@@ -28,16 +29,24 @@ __cxx11_abi_flag__ = pywrap_tensorflow.__cxx11_abi_flag__
__monolithic_build__ = pywrap_tensorflow.__monolithic_build__
VERSION = __version__
+tf_export("VERSION").export_constant(__name__, "VERSION")
GIT_VERSION = __git_version__
+tf_export("GIT_VERSION").export_constant(__name__, "GIT_VERSION")
COMPILER_VERSION = __compiler_version__
+tf_export("COMPILER_VERSION").export_constant(__name__, "COMPILER_VERSION")
CXX11_ABI_FLAG = __cxx11_abi_flag__
MONOLITHIC_BUILD = __monolithic_build__
GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION
+tf_export("GRAPH_DEF_VERSION").export_constant(__name__, "GRAPH_DEF_VERSION")
GRAPH_DEF_VERSION_MIN_CONSUMER = (
pywrap_tensorflow.GRAPH_DEF_VERSION_MIN_CONSUMER)
+tf_export("GRAPH_DEF_VERSION_MIN_CONSUMER").export_constant(
+ __name__, "GRAPH_DEF_VERSION_MIN_CONSUMER")
GRAPH_DEF_VERSION_MIN_PRODUCER = (
pywrap_tensorflow.GRAPH_DEF_VERSION_MIN_PRODUCER)
+tf_export("GRAPH_DEF_VERSION_MIN_PRODUCER").export_constant(
+ __name__, "GRAPH_DEF_VERSION_MIN_PRODUCER")
__all__ = [
"__version__",
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py
index 146bb4311c..61dc4e2afb 100644
--- a/tensorflow/python/grappler/cost_analyzer_tool.py
+++ b/tensorflow/python/grappler/cost_analyzer_tool.py
@@ -23,18 +23,33 @@ import sys
from google.protobuf import text_format
+from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
+from tensorflow.python.training import saver
def main(_):
- with gfile.GFile(FLAGS.input) as input_file:
- metagraph = meta_graph_pb2.MetaGraphDef()
- metagraph.ParseFromString(input_file.read())
+ if FLAGS.metagraphdef:
+ with gfile.GFile(FLAGS.metagraphdef) as meta_file:
+ metagraph = meta_graph_pb2.MetaGraphDef()
+ metagraph.ParseFromString(meta_file.read())
+ else:
+ with gfile.GFile(FLAGS.graphdef) as graph_file:
+ graph_def = graph_pb2.GraphDef()
+ graph_def.ParseFromString(graph_file.read())
+ importer.import_graph_def(graph_def, name="")
+ graph = ops.get_default_graph()
+ fetch = graph.get_operation_by_name(FLAGS.fetch)
+ graph.add_to_collection("train_op", fetch)
+ metagraph = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(), graph=graph)
if FLAGS.rewriter_config is not None:
rewriter_config = rewriter_config_pb2.RewriterConfig()
@@ -49,7 +64,25 @@ def main(_):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "--input", type=str, default=None, help="Input .meta file path.")
+ "--metagraphdef",
+ type=str,
+ default=None,
+ help="Input .meta MetaGraphDef file path.")
+ parser.add_argument(
+ "--graphdef",
+ type=str,
+ default=None,
+ help="Input .pb GraphDef file path.")
+ # Consider making flag fetch work together with flag metagraphdef. As some
+ # MetaGraphDef files don't have collection train_op.
+ parser.add_argument(
+ "--fetch",
+ type=str,
+ default=None,
+ help=
+ "The name of the fetch node. This flag is ignored if flag "
+ "metagraphdef is used."
+ )
parser.add_argument(
"--rewriter_config",
type=str,
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 25c5ef6b68..578f86ca5a 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -376,7 +376,7 @@ class LayoutOptimizerTest(test.TestCase):
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_nchw_to_nhwc('Pad-0-0', nodes)
- self.assertIn('Pad-PaddingsConst-LayoutOptimizer', nodes)
+ self.assertIn('Pad-1-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testReduceSum(self):
@@ -587,7 +587,7 @@ class LayoutOptimizerTest(test.TestCase):
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_nchw_to_nhwc('concat-0-0', nodes)
- self.assertIn('concat-Const_2-LayoutOptimizer', nodes)
+ self.assertIn('concat-2-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testFill(self):
@@ -698,7 +698,7 @@ class LayoutOptimizerTest(test.TestCase):
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_nchw_to_nhwc('ReverseV2-0-0', nodes)
- self.assertIn('ReverseV2-DimsConst-LayoutOptimizer', nodes)
+ self.assertIn('ReverseV2-1-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testReverseWithNonConstDims(self):
@@ -867,7 +867,7 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_nchw_to_nhwc('MaxPoolV2-0-0', nodes)
self._assert_vec_nhwc_to_nchw('MaxPoolV2-2', nodes)
- self.assertIn('MaxPoolV2-Const_2-LayoutOptimizer', nodes)
+ self.assertIn('MaxPoolV2-1-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testMaxPoolGradV2(self):
@@ -904,7 +904,7 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_nchw_to_nhwc('MaxPoolGradV2-0-0', nodes)
self._assert_vec_nhwc_to_nchw('MaxPoolGradV2-4', nodes)
- self.assertIn('MaxPoolGradV2-Const_2-LayoutOptimizer', nodes)
+ self.assertIn('MaxPoolGradV2-3-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testSliceWithNonConstAxis(self):
@@ -977,16 +977,17 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_nchw_to_nhwc('StridedSlice-0-0', nodes)
self._assert_vec_nhwc_to_nchw('StridedSlice-2', nodes)
- self.assertIn('StridedSlice-StridedSlice/begin-LayoutOptimizer', nodes)
- self.assertIn('StridedSlice-StridedSlice/strides-LayoutOptimizer', nodes)
+ self.assertIn('StridedSlice-1-LayoutOptimizer', nodes)
+ self.assertIn('StridedSlice-3-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
- def testStridedSliceWithMask(self):
+ def testStridedSliceWithMask1011(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
- # This will generate a StridedSlice op with begin mask and end mask.
+ # This will generate a StridedSlice op with begin mask and
+ # end mask 11(1011).
s = conv[:, :, 1:-1, :]
output = array_ops.identity(s)
@@ -1010,11 +1011,44 @@ class LayoutOptimizerTest(test.TestCase):
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_nchw_to_nhwc('strided_slice-0-0', nodes)
- self.assertIn('strided_slice-strided_slice/stack-LayoutOptimizer', nodes)
- self.assertIn('strided_slice-strided_slice/stack_1-LayoutOptimizer',
- nodes)
- self.assertIn('strided_slice-strided_slice/stack_2-LayoutOptimizer',
- nodes)
+ self.assertIn('strided_slice-1-LayoutOptimizer', nodes)
+ self.assertIn('strided_slice-2-LayoutOptimizer', nodes)
+ self.assertIn('strided_slice-3-LayoutOptimizer', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+ def testStridedSliceWithMask0111(self):
+ if test.is_gpu_available(cuda_only=True):
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ # This will generate a StridedSlice op with begin mask and
+ # end mask 7(0111).
+ s = conv[:, :, :, 1:-1]
+ output = array_ops.identity(s)
+
+ with session.Session() as sess:
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if _is_transpose(node.name):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Four transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 2
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+ self._assert_trans_nchw_to_nhwc('strided_slice-0-0', nodes)
+ self.assertIn('strided_slice-1-LayoutOptimizer', nodes)
+ self.assertIn('strided_slice-2-LayoutOptimizer', nodes)
+ self.assertIn('strided_slice-3-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testStridedSliceGradWithNonConstAxis(self):
@@ -1055,10 +1089,8 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_nchw_to_nhwc('StridedSliceGrad-0-0', nodes)
self._assert_vec_nhwc_to_nchw('StridedSliceGrad-2', nodes)
- self.assertIn('StridedSlice-StridedSliceGrad/begin-LayoutOptimizer',
- nodes)
- self.assertIn('StridedSlice-StridedSliceGrad/strides-LayoutOptimizer',
- nodes)
+ self.assertIn('StridedSlice-1-LayoutOptimizer', nodes)
+ self.assertIn('StridedSlice-2-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testShapeN(self):
diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i
index f0dd4483a6..1b657983a4 100644
--- a/tensorflow/python/grappler/tf_optimizer.i
+++ b/tensorflow/python/grappler/tf_optimizer.i
@@ -103,6 +103,11 @@ PyObject* TF_OptimizeGraph(
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
+ if (!grappler_item) {
+ TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
+ return nullptr;
+ }
+
tensorflow::DeviceBase* cpu_device = nullptr;
tensorflow::GraphDef out_graph;
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, rewriter_config);
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 1f20b3ae0e..6125755775 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -14,10 +14,12 @@ py_library(
"_impl/keras/__init__.py",
"_impl/keras/activations.py",
"_impl/keras/applications/__init__.py",
+ "_impl/keras/applications/densenet.py",
"_impl/keras/applications/imagenet_utils.py",
"_impl/keras/applications/inception_resnet_v2.py",
"_impl/keras/applications/inception_v3.py",
"_impl/keras/applications/mobilenet.py",
+ "_impl/keras/applications/nasnet.py",
"_impl/keras/applications/resnet50.py",
"_impl/keras/applications/vgg16.py",
"_impl/keras/applications/vgg19.py",
@@ -76,9 +78,11 @@ py_library(
"_impl/keras/wrappers/scikit_learn.py",
"activations/__init__.py",
"applications/__init__.py",
+ "applications/densenet/__init__.py",
"applications/inception_resnet_v2/__init__.py",
"applications/inception_v3/__init__.py",
"applications/mobilenet/__init__.py",
+ "applications/nasnet/__init__.py",
"applications/resnet50/__init__.py",
"applications/vgg16/__init__.py",
"applications/vgg19/__init__.py",
@@ -257,6 +261,18 @@ py_test(
)
py_test(
+ name = "densenet_test",
+ size = "large",
+ srcs = ["_impl/keras/applications/densenet_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "inception_resnet_v2_test",
size = "medium",
srcs = ["_impl/keras/applications/inception_resnet_v2_test.py"],
@@ -293,6 +309,18 @@ py_test(
)
py_test(
+ name = "nasnet_test",
+ size = "large",
+ srcs = ["_impl/keras/applications/nasnet_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "resnet50_test",
size = "small",
srcs = ["_impl/keras/applications/resnet50_test.py"],
@@ -504,7 +532,7 @@ py_test(
py_test(
name = "recurrent_test",
- size = "small",
+ size = "medium",
srcs = ["_impl/keras/layers/recurrent_test.py"],
srcs_version = "PY2AND3",
deps = [
@@ -527,7 +555,7 @@ py_test(
py_test(
name = "wrappers_test",
- size = "small",
+ size = "medium",
srcs = ["_impl/keras/layers/wrappers_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py
index a70250d796..7311353932 100644
--- a/tensorflow/python/keras/_impl/keras/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/__init__.py
@@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.models import Sequential
-__version__ = '2.1.2-tf'
+__version__ = '2.1.3-tf'
diff --git a/tensorflow/python/keras/_impl/keras/activations.py b/tensorflow/python/keras/_impl/keras/activations.py
index f017d2ae85..4852b8c36a 100644
--- a/tensorflow/python/keras/_impl/keras/activations.py
+++ b/tensorflow/python/keras/_impl/keras/activations.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras built-in activation functions.
+"""Built-in activation functions.
"""
from __future__ import absolute_import
from __future__ import division
@@ -61,10 +61,12 @@ def selu(x):
x: A tensor or variable to compute the activation function for.
Returns:
- Tensor with the same shape and dtype as `x`.
+ Tensor with the same shape and dtype as `x`.
+
+ # Note
+ - To be used together with the initialization "lecun_normal".
+ - To be used together with the dropout variant "AlphaDropout".
- References:
- - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
"""
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
diff --git a/tensorflow/python/keras/_impl/keras/applications/__init__.py b/tensorflow/python/keras/_impl/keras/applications/__init__.py
index c11c52b71e..206a769b37 100644
--- a/tensorflow/python/keras/_impl/keras/applications/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/applications/__init__.py
@@ -18,9 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201
from tensorflow.python.keras._impl.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3
from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet
+from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge
+from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile
from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50
from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16
from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19
diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet.py b/tensorflow/python/keras/_impl/keras/applications/densenet.py
new file mode 100644
index 0000000000..9e40d34930
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/applications/densenet.py
@@ -0,0 +1,346 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=invalid-name
+# pylint: disable=unused-import
+"""DenseNet models for Keras.
+
+# Reference paper
+
+- [Densely Connected Convolutional Networks]
+ (https://arxiv.org/abs/1608.06993) (CVPR 2017 Best Paper Award)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.applications import imagenet_utils
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Activation
+from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers import Concatenate
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+
+
+DENSENET121_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet121_weights_tf_dim_ordering_tf_kernels.h5'
+DENSENET121_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5'
+DENSENET169_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet169_weights_tf_dim_ordering_tf_kernels.h5'
+DENSENET169_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5'
+DENSENET201_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels.h5'
+DENSENET201_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5'
+
+
+def dense_block(x, blocks, name):
+ """A dense block.
+
+ Arguments:
+ x: input tensor.
+ blocks: integer, the number of building blocks.
+ name: string, block label.
+
+ Returns:
+ output tensor for the block.
+ """
+ for i in range(blocks):
+ x = conv_block(x, 32, name=name + '_block' + str(i + 1))
+ return x
+
+
+def transition_block(x, reduction, name):
+ """A transition block.
+
+ Arguments:
+ x: input tensor.
+ reduction: float, compression rate at transition layers.
+ name: string, block label.
+
+ Returns:
+ output tensor for the block.
+ """
+ bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
+ x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + '_bn')(x)
+ x = Activation('relu', name=name + '_relu')(x)
+ x = Conv2D(
+ int(K.int_shape(x)[bn_axis] * reduction),
+ 1,
+ use_bias=False,
+ name=name + '_conv')(
+ x)
+ x = AveragePooling2D(2, strides=2, name=name + '_pool')(x)
+ return x
+
+
+def conv_block(x, growth_rate, name):
+ """A building block for a dense block.
+
+ Arguments:
+ x: input tensor.
+ growth_rate: float, growth rate at dense layers.
+ name: string, block label.
+
+ Returns:
+ output tensor for the block.
+ """
+ bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
+ x1 = BatchNormalization(
+ axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')(
+ x)
+ x1 = Activation('relu', name=name + '_0_relu')(x1)
+ x1 = Conv2D(4 * growth_rate, 1, use_bias=False, name=name + '_1_conv')(x1)
+ x1 = BatchNormalization(
+ axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(
+ x1)
+ x1 = Activation('relu', name=name + '_1_relu')(x1)
+ x1 = Conv2D(
+ growth_rate, 3, padding='same', use_bias=False, name=name + '_2_conv')(
+ x1)
+ x = Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
+ return x
+
+
+def DenseNet(blocks,
+ include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000):
+ """Instantiates the DenseNet architecture.
+
+ Optionally loads weights pre-trained
+ on ImageNet. Note that when using TensorFlow,
+ for best performance you should set
+ `image_data_format='channels_last'` in your Keras config
+ at ~/.keras/keras.json.
+
+ The model and the weights are compatible with
+ TensorFlow, Theano, and CNTK. The data format
+ convention used by the model is the one
+ specified in your Keras config file.
+
+ Arguments:
+ blocks: numbers of building blocks for the four dense layers.
+ include_top: whether to include the fully-connected
+ layer at the top of the network.
+ weights: one of `None` (random initialization),
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
+ input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
+ to use as image input for the model.
+ input_shape: optional shape tuple, only to be specified
+ if `include_top` is False (otherwise the input shape
+ has to be `(224, 224, 3)` (with `channels_last` data format)
+ or `(3, 224, 224)` (with `channels_first` data format).
+ It should have exactly 3 inputs channels.
+ pooling: optional pooling mode for feature extraction
+ when `include_top` is `False`.
+ - `None` means that the output of the model will be
+ the 4D tensor output of the
+ last convolutional layer.
+ - `avg` means that global average pooling
+ will be applied to the output of the
+ last convolutional layer, and thus
+ the output of the model will be a 2D tensor.
+ - `max` means that global max pooling will
+ be applied.
+ classes: optional number of classes to classify images
+ into, only to be specified if `include_top` is True, and
+ if no `weights` argument is specified.
+
+ Returns:
+ A Keras model instance.
+
+ Raises:
+ ValueError: in case of invalid argument for `weights`,
+ or invalid input shape.
+ """
+ if not (weights in {'imagenet', None} or os.path.exists(weights)):
+ raise ValueError('The `weights` argument should be either '
+ '`None` (random initialization), `imagenet` '
+ '(pre-training on ImageNet), '
+ 'or the path to the weights file to be loaded.')
+
+ if weights == 'imagenet' and include_top and classes != 1000:
+ raise ValueError('If using `weights` as imagenet with `include_top`'
+ ' as true, `classes` should be 1000')
+
+ # Determine proper input shape
+ input_shape = _obtain_input_shape(
+ input_shape,
+ default_size=224,
+ min_size=221,
+ data_format=K.image_data_format(),
+ require_flatten=include_top,
+ weights=weights)
+
+ if input_tensor is None:
+ img_input = Input(shape=input_shape)
+ else:
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
+
+ bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
+
+ x = ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
+ x = Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
+ x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)
+ x = Activation('relu', name='conv1/relu')(x)
+ x = ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
+ x = MaxPooling2D(3, strides=2, name='pool1')(x)
+
+ x = dense_block(x, blocks[0], name='conv2')
+ x = transition_block(x, 0.5, name='pool2')
+ x = dense_block(x, blocks[1], name='conv3')
+ x = transition_block(x, 0.5, name='pool3')
+ x = dense_block(x, blocks[2], name='conv4')
+ x = transition_block(x, 0.5, name='pool4')
+ x = dense_block(x, blocks[3], name='conv5')
+
+ x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
+
+ if include_top:
+ x = GlobalAveragePooling2D(name='avg_pool')(x)
+ x = Dense(classes, activation='softmax', name='fc1000')(x)
+ else:
+ if pooling == 'avg':
+ x = GlobalAveragePooling2D(name='avg_pool')(x)
+ elif pooling == 'max':
+ x = GlobalMaxPooling2D(name='max_pool')(x)
+
+ # Ensure that the model takes into account
+ # any potential predecessors of `input_tensor`.
+ if input_tensor is not None:
+ inputs = get_source_inputs(input_tensor)
+ else:
+ inputs = img_input
+
+ # Create model.
+ if blocks == [6, 12, 24, 16]:
+ model = Model(inputs, x, name='densenet121')
+ elif blocks == [6, 12, 32, 32]:
+ model = Model(inputs, x, name='densenet169')
+ elif blocks == [6, 12, 48, 32]:
+ model = Model(inputs, x, name='densenet201')
+ else:
+ model = Model(inputs, x, name='densenet')
+
+ # Load weights.
+ if weights == 'imagenet':
+ if include_top:
+ if blocks == [6, 12, 24, 16]:
+ weights_path = get_file(
+ 'densenet121_weights_tf_dim_ordering_tf_kernels.h5',
+ DENSENET121_WEIGHT_PATH,
+ cache_subdir='models',
+ file_hash='0962ca643bae20f9b6771cb844dca3b0')
+ elif blocks == [6, 12, 32, 32]:
+ weights_path = get_file(
+ 'densenet169_weights_tf_dim_ordering_tf_kernels.h5',
+ DENSENET169_WEIGHT_PATH,
+ cache_subdir='models',
+ file_hash='bcf9965cf5064a5f9eb6d7dc69386f43')
+ elif blocks == [6, 12, 48, 32]:
+ weights_path = get_file(
+ 'densenet201_weights_tf_dim_ordering_tf_kernels.h5',
+ DENSENET201_WEIGHT_PATH,
+ cache_subdir='models',
+ file_hash='7bb75edd58cb43163be7e0005fbe95ef')
+ else:
+ if blocks == [6, 12, 24, 16]:
+ weights_path = get_file(
+ 'densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5',
+ DENSENET121_WEIGHT_PATH_NO_TOP,
+ cache_subdir='models',
+ file_hash='4912a53fbd2a69346e7f2c0b5ec8c6d3')
+ elif blocks == [6, 12, 32, 32]:
+ weights_path = get_file(
+ 'densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5',
+ DENSENET169_WEIGHT_PATH_NO_TOP,
+ cache_subdir='models',
+ file_hash='50662582284e4cf834ce40ab4dfa58c6')
+ elif blocks == [6, 12, 48, 32]:
+ weights_path = get_file(
+ 'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5',
+ DENSENET201_WEIGHT_PATH_NO_TOP,
+ cache_subdir='models',
+ file_hash='1c2de60ee40562448dbac34a0737e798')
+ model.load_weights(weights_path)
+ elif weights is not None:
+ model.load_weights(weights)
+
+ return model
+
+
+def DenseNet121(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000):
+ return DenseNet([6, 12, 24, 16], include_top, weights, input_tensor,
+ input_shape, pooling, classes)
+
+
+def DenseNet169(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000):
+ return DenseNet([6, 12, 32, 32], include_top, weights, input_tensor,
+ input_shape, pooling, classes)
+
+
+def DenseNet201(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000):
+ return DenseNet([6, 12, 48, 32], include_top, weights, input_tensor,
+ input_shape, pooling, classes)
+
+
+def preprocess_input(x, data_format=None):
+ """Preprocesses a numpy array encoding a batch of images.
+
+ Arguments:
+ x: a 3D or 4D numpy array consists of RGB values within [0, 255].
+ data_format: data format of the image tensor.
+
+ Returns:
+ Preprocessed array.
+ """
+ return imagenet_utils.preprocess_input(x, data_format, mode='torch')
+
+
+setattr(DenseNet121, '__doc__', DenseNet.__doc__)
+setattr(DenseNet169, '__doc__', DenseNet.__doc__)
+setattr(DenseNet201, '__doc__', DenseNet.__doc__)
diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet_test.py b/tensorflow/python/keras/_impl/keras/applications/densenet_test.py
new file mode 100644
index 0000000000..3b92287a1e
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/applications/densenet_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DenseNet application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.platform import test
+
+
+class DenseNet121Test(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.DenseNet121(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.DenseNet121(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 1024))
+
+ def test_with_pooling(self):
+ model = keras.applications.DenseNet121(weights=None,
+ include_top=False,
+ pooling='avg')
+ self.assertEqual(model.output_shape, (None, 1024))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet121(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet121(weights='imagenet',
+ classes=2000)
+
+
+class DenseNet169Test(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.DenseNet169(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.DenseNet169(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 1664))
+
+ def test_with_pooling(self):
+ model = keras.applications.DenseNet169(weights=None,
+ include_top=False,
+ pooling='max')
+ self.assertEqual(model.output_shape, (None, 1664))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet169(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet169(weights='imagenet',
+ classes=2000)
+
+
+class DenseNet201(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.DenseNet201(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.DenseNet201(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 1920))
+
+ def test_with_pooling(self):
+ model = keras.applications.DenseNet201(weights=None,
+ include_top=False,
+ pooling='avg')
+ self.assertEqual(model.output_shape, (None, 1920))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet201(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet201(weights='imagenet',
+ classes=2000)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
index 63ee83cb51..f1f20f12a8 100644
--- a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
+++ b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Utilities used by models pre-trained on ImageNet.
+"""Utilities for ImageNet data preprocessing & prediction decoding.
"""
from __future__ import absolute_import
from __future__ import division
@@ -35,63 +35,92 @@ _IMAGENET_MEAN = None
def _preprocess_numpy_input(x, data_format, mode):
- """Preprocesses a image tensor as a Numpy array.
+ """Preprocesses a Numpy array encoding a batch of images.
Arguments:
- x: input Numpy, 3D or 4D.
- data_format: data format of the image tensor.
- mode: One of "caffe", "tf".
+ x: Input array, 3D or 4D.
+ data_format: Data format of the image array.
+ mode: One of "caffe", "tf" or "torch".
- caffe: will convert the images from RGB to BGR,
then will zero-center each color channel with
respect to the ImageNet dataset,
without scaling.
- tf: will scale pixels between -1 and 1,
sample-wise.
+ - torch: will scale pixels between 0 and 1 and then
+ will normalize each channel with respect to the
+ ImageNet dataset.
Returns:
- Preprocessed array.
+ Preprocessed Numpy array.
"""
if mode == 'tf':
x /= 127.5
x -= 1.
return x
+ if mode == 'torch':
+ x /= 255.
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ else:
+ if data_format == 'channels_first':
+ # 'RGB'->'BGR'
+ if x.ndim == 3:
+ x = x[::-1, ...]
+ else:
+ x = x[:, ::-1, ...]
+ else:
+ # 'RGB'->'BGR'
+ x = x[..., ::-1]
+ mean = [103.939, 116.779, 123.68]
+ std = None
+
+ # Zero-center by mean pixel
if data_format == 'channels_first':
if x.ndim == 3:
- # 'RGB'->'BGR'
- x = x[::-1, ...]
- # Zero-center by mean pixel
- x[0, :, :] -= 103.939
- x[1, :, :] -= 116.779
- x[2, :, :] -= 123.68
+ x[0, :, :] -= mean[0]
+ x[1, :, :] -= mean[1]
+ x[2, :, :] -= mean[2]
+ if std is not None:
+ x[0, :, :] /= std[0]
+ x[1, :, :] /= std[1]
+ x[2, :, :] /= std[2]
else:
- x = x[:, ::-1, ...]
- x[:, 0, :, :] -= 103.939
- x[:, 1, :, :] -= 116.779
- x[:, 2, :, :] -= 123.68
+ x[:, 0, :, :] -= mean[0]
+ x[:, 1, :, :] -= mean[1]
+ x[:, 2, :, :] -= mean[2]
+ if std is not None:
+ x[:, 0, :, :] /= std[0]
+ x[:, 1, :, :] /= std[1]
+ x[:, 2, :, :] /= std[2]
else:
- # 'RGB'->'BGR'
- x = x[..., ::-1]
- # Zero-center by mean pixel
- x[..., 0] -= 103.939
- x[..., 1] -= 116.779
- x[..., 2] -= 123.68
+ x[..., 0] -= mean[0]
+ x[..., 1] -= mean[1]
+ x[..., 2] -= mean[2]
+ if std is not None:
+ x[..., 0] /= std[0]
+ x[..., 1] /= std[1]
+ x[..., 2] /= std[2]
return x
def _preprocess_symbolic_input(x, data_format, mode):
- """Preprocesses a symbolic image tensor.
+ """Preprocesses a tensor encoding a batch of images.
Arguments:
- x: symoblic tensor, 3D or 4D.
- data_format: data format of the image tensor.
- mode: One of "caffe", "tf".
+ x: Input tensor, 3D or 4D.
+ data_format: Data format of the image tensor.
+ mode: One of "caffe", "tf" or "torch".
- caffe: will convert the images from RGB to BGR,
then will zero-center each color channel with
respect to the ImageNet dataset,
without scaling.
- tf: will scale pixels between -1 and 1,
sample-wise.
+ - torch: will scale pixels between 0 and 1 and then
+ will normalize each channel with respect to the
+ ImageNet dataset.
Returns:
Preprocessed tensor.
@@ -103,32 +132,42 @@ def _preprocess_symbolic_input(x, data_format, mode):
x -= 1.
return x
- if data_format == 'channels_first':
- # 'RGB'->'BGR'
- if K.ndim(x) == 3:
- x = x[::-1, ...]
- else:
- x = x[:, ::-1, ...]
+ if mode == 'torch':
+ x /= 255.
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
else:
- # 'RGB'->'BGR'
- x = x[..., ::-1]
+ if data_format == 'channels_first':
+ # 'RGB'->'BGR'
+ if K.ndim(x) == 3:
+ x = x[::-1, ...]
+ else:
+ x = x[:, ::-1, ...]
+ else:
+ # 'RGB'->'BGR'
+ x = x[..., ::-1]
+ mean = [103.939, 116.779, 123.68]
+ std = None
if _IMAGENET_MEAN is None:
- _IMAGENET_MEAN = K.constant(-np.array([103.939, 116.779, 123.68]))
+ _IMAGENET_MEAN = K.constant(-np.array(mean))
+
# Zero-center by mean pixel
if K.dtype(x) != K.dtype(_IMAGENET_MEAN):
x = K.bias_add(x, K.cast(_IMAGENET_MEAN, K.dtype(x)), data_format)
else:
x = K.bias_add(x, _IMAGENET_MEAN, data_format)
+ if std is not None:
+ x /= std
return x
def preprocess_input(x, data_format=None, mode='caffe'):
- """Preprocesses a tensor encoding a batch of images.
+ """Preprocesses a tensor or Numpy array encoding a batch of images.
Arguments:
- x: input Numpy or symoblic tensor, 3D or 4D.
- data_format: data format of the image tensor.
+ x: Input Numpy or symbolic tensor, 3D or 4D.
+ data_format: Data format of the image tensor/array.
mode: One of "caffe", "tf".
- caffe: will convert the images from RGB to BGR,
then will zero-center each color channel with
@@ -138,10 +177,10 @@ def preprocess_input(x, data_format=None, mode='caffe'):
sample-wise.
Returns:
- Preprocessed tensor.
+ Preprocessed tensor or Numpy array.
Raises:
- ValueError: in case of incorrect data_format.
+ ValueError: In case of unknown `data_format` argument.
"""
if data_format is None:
data_format = K.image_data_format()
@@ -159,7 +198,7 @@ def decode_predictions(preds, top=5):
Arguments:
preds: Numpy tensor encoding a batch of predictions.
- top: integer, how many top-guesses to return.
+ top: Integer, how many top-guesses to return.
Returns:
A list of lists of top class prediction tuples
@@ -167,7 +206,7 @@ def decode_predictions(preds, top=5):
One list of tuples per sample in batch input.
Raises:
- ValueError: in case of invalid shape of the `pred` array
+ ValueError: In case of invalid shape of the `pred` array
(must be 2D).
"""
global CLASS_INDEX
@@ -177,10 +216,11 @@ def decode_predictions(preds, top=5):
'(i.e. a 2D array of shape (samples, 1000)). '
'Found array with shape: ' + str(preds.shape))
if CLASS_INDEX is None:
- fpath = get_file('imagenet_class_index.json',
- CLASS_INDEX_PATH,
- cache_subdir='models',
- file_hash='c2c37ea517e94d9795004a39431a14cb')
+ fpath = get_file(
+ 'imagenet_class_index.json',
+ CLASS_INDEX_PATH,
+ cache_subdir='models',
+ file_hash='c2c37ea517e94d9795004a39431a14cb')
CLASS_INDEX = json.load(open(fpath))
results = []
for pred in preds:
@@ -197,17 +237,17 @@ def _obtain_input_shape(input_shape,
data_format,
require_flatten,
weights=None):
- """Internal utility to compute/validate an ImageNet model's input shape.
+ """Internal utility to compute/validate a model's input shape.
Arguments:
- input_shape: either None (will return the default network input shape),
+ input_shape: Either None (will return the default network input shape),
or a user-provided shape to be validated.
- default_size: default input width/height for the model.
- min_size: minimum input width/height accepted by the model.
- data_format: image data format to use.
- require_flatten: whether the model is expected to
+ default_size: Default input width/height for the model.
+ min_size: Minimum input width/height accepted by the model.
+ data_format: Image data format to use.
+ require_flatten: Whether the model is expected to
be linked to a classifier via a Flatten layer.
- weights: one of `None` (random initialization)
+ weights: One of `None` (random initialization)
or 'imagenet' (pre-training on ImageNet).
If weights='imagenet' input channels must be equal to 3.
@@ -215,7 +255,7 @@ def _obtain_input_shape(input_shape,
An integer shape tuple (may include None entries).
Raises:
- ValueError: in case of invalid argument values.
+ ValueError: In case of invalid argument values.
"""
if weights != 'imagenet' and input_shape and len(input_shape) == 3:
if data_format == 'channels_first':
@@ -252,8 +292,8 @@ def _obtain_input_shape(input_shape,
'`input_shape=' + str(input_shape) + '`')
if ((input_shape[1] is not None and input_shape[1] < min_size) or
(input_shape[2] is not None and input_shape[2] < min_size)):
- raise ValueError('Input size must be at least ' + str(min_size) + 'x'
- + str(min_size) + '; got '
+ raise ValueError('Input size must be at least ' + str(min_size) +
+ 'x' + str(min_size) + '; got '
'`input_shape=' + str(input_shape) + '`')
else:
if input_shape is not None:
@@ -264,8 +304,8 @@ def _obtain_input_shape(input_shape,
'`input_shape=' + str(input_shape) + '`')
if ((input_shape[0] is not None and input_shape[0] < min_size) or
(input_shape[1] is not None and input_shape[1] < min_size)):
- raise ValueError('Input size must be at least ' + str(min_size) + 'x'
- + str(min_size) + '; got '
+ raise ValueError('Input size must be at least ' + str(min_size) +
+ 'x' + str(min_size) + '; got '
'`input_shape=' + str(input_shape) + '`')
else:
if require_flatten:
diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
index 2e73cefb6c..1dc15b5b34 100644
--- a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""Inception-ResNet V2 model for Keras.
# Reference
@@ -28,7 +30,7 @@ import os
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
@@ -43,6 +45,8 @@ from tensorflow.python.keras._impl.keras.layers import Lambda
from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
+
BASE_WEIGHT_URL = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.7/'
@@ -116,7 +120,8 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
scale: scaling factor to scale the residuals (i.e., the output of
passing `x` through an inception module) before adding them
to the shortcut branch. Let `r` be the output from the residual
- branch, the output of this block will be `x + scale * r`.
+ branch,
+ the output of this block will be `x + scale * r`.
block_type: `'block35'`, `'block17'` or `'block8'`, determines
the network structure in the residual branch.
block_idx: an `int` used for generating layer names. The Inception-ResNet
@@ -128,8 +133,7 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
will have `block_type='block35', block_idx=0`, ane the layer names
will have
a common prefix `'block35_0'`.
- activation: activation function to use at the end of the block
- (see [activations](../activations.md)).
+ activation: activation function to use at the end of the block.
When `activation=None`, no activation is applied
(i.e., "linear" activation: `a(x) = x`).
@@ -178,6 +182,7 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
x = Lambda(
lambda inputs, scale: inputs[0] + inputs[1] * scale,
+ output_shape=K.int_shape(x)[1:],
arguments={'scale': scale},
name=block_name)([x, up])
if activation is not None:
@@ -185,7 +190,7 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
return x
-def InceptionResNetV2(include_top=True, # pylint: disable=invalid-name
+def InceptionResNetV2(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
@@ -211,8 +216,8 @@ def InceptionResNetV2(include_top=True, # pylint: disable=invalid-name
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
index 4424b92804..ff57116f2d 100644
--- a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""Inception V3 model for Keras.
Note that the input image format for this model is different than for
@@ -35,7 +36,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
@@ -48,6 +49,7 @@ from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels.h5'
@@ -92,7 +94,8 @@ def conv2d_bn(x,
strides=strides,
padding=padding,
use_bias=False,
- name=conv_name)(x)
+ name=conv_name)(
+ x)
x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
x = Activation('relu', name=name)(x)
return x
@@ -109,7 +112,7 @@ def InceptionV3(include_top=True,
Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
- `image_data_format="channels_last"` in your Keras config
+ `image_data_format='channels_last'` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
TensorFlow and Theano. The data format
@@ -121,15 +124,15 @@ def InceptionV3(include_top=True,
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- "imagenet" (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(299, 299, 3)` (with `channels_last` data format)
or `(3, 299, 299)` (with `channels_first` data format).
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 139.
E.g. `(150, 150, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
@@ -176,7 +179,10 @@ def InceptionV3(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
if K.image_data_format() == 'channels_first':
channel_axis = 1
@@ -389,6 +395,7 @@ def InceptionV3(include_top=True,
model.load_weights(weights_path)
elif weights is not None:
model.load_weights(weights)
+
return model
diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
index 5f97c138fc..790bf8cead 100644
--- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""MobileNet v1 models for Keras.
MobileNet is a general architecture and can be used for multiple use cases.
@@ -56,7 +58,7 @@ the 100 % MobileNet on various input sizes:
------------------------------------------------------------------------
The weights for all 16 models are obtained and translated
-from Tensorflow checkpoints found at
+from TensorFlow checkpoints found at
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md
# Reference
@@ -75,9 +77,10 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
from tensorflow.python.keras._impl.keras.layers import Conv2D
@@ -91,6 +94,7 @@ from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
+
BASE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/'
@@ -130,7 +134,7 @@ class DepthwiseConv2D(Conv2D):
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
- padding: one of `"valid"` or `"same"` (case-insensitive).
+ padding: one of `'valid'` or `'same'` (case-insensitive).
depth_multiplier: The number of depthwise convolution output channels
for each input channel.
The total number of depthwise convolution output
@@ -144,29 +148,21 @@ class DepthwiseConv2D(Conv2D):
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
- If you never set it, then it will be "channels_last".
- activation: Activation function to use
- (see [activations](../activations.md)).
+ If you never set it, then it will be 'channels_last'.
+ activation: Activation function to use.
If you don't specify anything, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
+ (ie. 'linear' activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
- depthwise_initializer: Initializer for the depthwise kernel matrix
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ depthwise_initializer: Initializer for the depthwise kernel matrix.
+ bias_initializer: Initializer for the bias vector.
depthwise_regularizer: Regularizer function applied to
- the depthwise kernel matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the depthwise kernel matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
+ the output of the layer (its 'activation')..
depthwise_constraint: Constraint function applied to
- the depthwise kernel matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the depthwise kernel matrix.
+ bias_constraint: Constraint function applied to the bias vector.
Input shape:
4D tensor with shape:
@@ -216,6 +212,7 @@ class DepthwiseConv2D(Conv2D):
self.depthwise_constraint = constraints.get(depthwise_constraint)
self.bias_initializer = initializers.get(bias_initializer)
+ @shape_type_conversion
def build(self, input_shape):
if len(input_shape) < 4:
raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
@@ -269,6 +266,7 @@ class DepthwiseConv2D(Conv2D):
return outputs
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
rows = input_shape[2]
@@ -305,7 +303,7 @@ class DepthwiseConv2D(Conv2D):
return config
-def MobileNet(input_shape=None, # pylint: disable=invalid-name
+def MobileNet(input_shape=None,
alpha=1.0,
depth_multiplier=1,
dropout=1e-3,
@@ -334,7 +332,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
or (3, 224, 224) (with `channels_first` data format).
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 32.
E.g. `(200, 200, 3)` would be one valid value.
alpha: controls the width of the network.
@@ -350,8 +348,8 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of
`layers.Input()`)
to use as image input for the model.
@@ -380,6 +378,12 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
RuntimeError: If attempting to run this model with a
backend that does not support separable convolutions.
"""
+
+ if K.backend() != 'tensorflow':
+ raise RuntimeError('Only TensorFlow backend is currently supported, '
+ 'as other backends do not support '
+ 'depthwise convolution.')
+
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
'`None` (random initialization), `imagenet` '
@@ -390,7 +394,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
raise ValueError('If using `weights` as ImageNet with `include_top` '
'as true, `classes` should be 1000')
- # Determine proper input shape.
+ # Determine proper input shape and default size.
if input_shape is None:
default_size = 224
else:
@@ -400,10 +404,12 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
else:
rows = input_shape[0]
cols = input_shape[1]
+
if rows == cols and rows in [128, 160, 192, 224]:
default_size = rows
else:
default_size = 224
+
input_shape = _obtain_input_shape(
input_shape,
default_size=default_size,
@@ -411,6 +417,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
data_format=K.image_data_format(),
require_flatten=include_top,
weights=weights)
+
if K.image_data_format() == 'channels_last':
row_axis, col_axis = (0, 1)
else:
@@ -536,8 +543,6 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
if old_data_format:
K.set_image_data_format(old_data_format)
- elif weights is not None:
- model.load_weights(weights)
return model
@@ -595,7 +600,8 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
padding='same',
use_bias=False,
strides=strides,
- name='conv1')(inputs)
+ name='conv1')(
+ inputs)
x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
return Activation(relu6, name='conv1_relu')(x)
@@ -662,7 +668,8 @@ def _depthwise_conv_block(inputs,
depth_multiplier=depth_multiplier,
strides=strides,
use_bias=False,
- name='conv_dw_%d' % block_id)(inputs)
+ name='conv_dw_%d' % block_id)(
+ inputs)
x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
@@ -671,6 +678,7 @@ def _depthwise_conv_block(inputs,
padding='same',
use_bias=False,
strides=(1, 1),
- name='conv_pw_%d' % block_id)(x)
+ name='conv_pw_%d' % block_id)(
+ x)
x = BatchNormalization(axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x)
return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)
diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet.py b/tensorflow/python/keras/_impl/keras/applications/nasnet.py
new file mode 100644
index 0000000000..5dd038c096
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/applications/nasnet.py
@@ -0,0 +1,783 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=line-too-long
+# pylint: disable=invalid-name
+# pylint: disable=unused-import
+"""NASNet-A models for Keras.
+
+NASNet refers to Neural Architecture Search Network, a family of models
+that were designed automatically by learning the model architectures
+directly on the dataset of interest.
+
+Here we consider NASNet-A, the highest performance model that was found
+for the CIFAR-10 dataset, and then extended to ImageNet 2012 dataset,
+obtaining state of the art performance on CIFAR-10 and ImageNet 2012.
+Only the NASNet-A models, and their respective weights, which are suited
+for ImageNet 2012 are provided.
+
+The below table describes the performance on ImageNet 2012:
+--------------------------------------------------------------------------------
+ Architecture | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M)
+--------------------------------------------------------------------------------
+| NASNet-A (4 @ 1056) | 74.0 % | 91.6 % | 564 M | 5.3 |
+| NASNet-A (6 @ 4032) | 82.7 % | 96.2 % | 23.8 B | 88.9 |
+--------------------------------------------------------------------------------
+
+References:
+ - [Learning Transferable Architectures for Scalable Image Recognition]
+ (https://arxiv.org/abs/1707.07012)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Activation
+from tensorflow.python.keras._impl.keras.layers import add
+from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers import concatenate
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Cropping2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import SeparableConv2D
+from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
+
+
+NASNET_MOBILE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile.h5'
+NASNET_MOBILE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile-no-top.h5'
+NASNET_LARGE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large.h5'
+NASNET_LARGE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large-no-top.h5'
+
+
+def NASNet(input_shape=None,
+ penultimate_filters=4032,
+ num_blocks=6,
+ stem_block_filters=96,
+ skip_reduction=True,
+ filter_multiplier=2,
+ include_top=True,
+ weights=None,
+ input_tensor=None,
+ pooling=None,
+ classes=1000,
+ default_size=None):
+ """Instantiates a NASNet model.
+
+ Note that only TensorFlow is supported for now,
+ therefore it only works with the data format
+ `image_data_format='channels_last'` in your Keras config
+ at `~/.keras/keras.json`.
+
+ Arguments:
+ input_shape: Optional shape tuple, only to be specified
+ if `include_top` is False (otherwise the input shape
+ has to be `(331, 331, 3)` for NASNetLarge or
+ `(224, 224, 3)` for NASNetMobile
+ It should have exactly 3 inputs channels,
+ and width and height should be no smaller than 32.
+ E.g. `(224, 224, 3)` would be one valid value.
+ penultimate_filters: Number of filters in the penultimate layer.
+ NASNet models use the notation `NASNet (N @ P)`, where:
+ - N is the number of blocks
+ - P is the number of penultimate filters
+ num_blocks: Number of repeated blocks of the NASNet model.
+ NASNet models use the notation `NASNet (N @ P)`, where:
+ - N is the number of blocks
+ - P is the number of penultimate filters
+ stem_block_filters: Number of filters in the initial stem block
+ skip_reduction: Whether to skip the reduction step at the tail
+ end of the network. Set to `False` for CIFAR models.
+ filter_multiplier: Controls the width of the network.
+ - If `filter_multiplier` < 1.0, proportionally decreases the number
+ of filters in each layer.
+ - If `filter_multiplier` > 1.0, proportionally increases the number
+ of filters in each layer.
+ - If `filter_multiplier` = 1, default number of filters from the
+ paper are used at each layer.
+ include_top: Whether to include the fully-connected
+ layer at the top of the network.
+ weights: `None` (random initialization) or
+ `imagenet` (ImageNet weights)
+ input_tensor: Optional Keras tensor (i.e. output of
+ `layers.Input()`)
+ to use as image input for the model.
+ pooling: Optional pooling mode for feature extraction
+ when `include_top` is `False`.
+ - `None` means that the output of the model
+ will be the 4D tensor output of the
+ last convolutional layer.
+ - `avg` means that global average pooling
+ will be applied to the output of the
+ last convolutional layer, and thus
+ the output of the model will be a
+ 2D tensor.
+ - `max` means that global max pooling will
+ be applied.
+ classes: Optional number of classes to classify images
+ into, only to be specified if `include_top` is True, and
+ if no `weights` argument is specified.
+ default_size: Specifies the default image size of the model
+
+ Returns:
+ A Keras model instance.
+
+ Raises:
+ ValueError: In case of invalid argument for `weights`,
+ invalid input shape or invalid `penultimate_filters` value.
+ RuntimeError: If attempting to run this model with a
+ backend that does not support separable convolutions.
+ """
+ if K.backend() != 'tensorflow':
+ raise RuntimeError('Only Tensorflow backend is currently supported, '
+ 'as other backends do not support '
+ 'separable convolution.')
+
+ if not (weights in {'imagenet', None} or os.path.exists(weights)):
+ raise ValueError('The `weights` argument should be either '
+ '`None` (random initialization), `imagenet` '
+ '(pre-training on ImageNet), '
+ 'or the path to the weights file to be loaded.')
+
+ if weights == 'imagenet' and include_top and classes != 1000:
+ raise ValueError('If using `weights` as ImageNet with `include_top` '
+ 'as true, `classes` should be 1000')
+
+ if default_size is None:
+ default_size = 331
+
+ # Determine proper input shape and default size.
+ input_shape = _obtain_input_shape(
+ input_shape,
+ default_size=default_size,
+ min_size=32,
+ data_format=K.image_data_format(),
+ require_flatten=include_top or weights,
+ weights=weights)
+
+ if K.image_data_format() != 'channels_last':
+ logging.warning('The NASNet family of models is only available '
+ 'for the input data format "channels_last" '
+ '(width, height, channels). '
+ 'However your settings specify the default '
+ 'data format "channels_first" (channels, width, height).'
+ ' You should set `image_data_format="channels_last"` '
+ 'in your Keras config located at ~/.keras/keras.json. '
+ 'The model being returned right now will expect inputs '
+ 'to follow the "channels_last" data format.')
+ K.set_image_data_format('channels_last')
+ old_data_format = 'channels_first'
+ else:
+ old_data_format = None
+
+ if input_tensor is None:
+ img_input = Input(shape=input_shape)
+ else:
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
+
+ if penultimate_filters % 24 != 0:
+ raise ValueError(
+ 'For NASNet-A models, the value of `penultimate_filters` '
+ 'needs to be divisible by 24. Current value: %d' % penultimate_filters)
+
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+ filters = penultimate_filters // 24
+
+ if not skip_reduction:
+ x = Conv2D(
+ stem_block_filters, (3, 3),
+ strides=(2, 2),
+ padding='valid',
+ use_bias=False,
+ name='stem_conv1',
+ kernel_initializer='he_normal')(
+ img_input)
+ else:
+ x = Conv2D(
+ stem_block_filters, (3, 3),
+ strides=(1, 1),
+ padding='same',
+ use_bias=False,
+ name='stem_conv1',
+ kernel_initializer='he_normal')(
+ img_input)
+
+ x = BatchNormalization(
+ axis=channel_dim, momentum=0.9997, epsilon=1e-3, name='stem_bn1')(
+ x)
+
+ p = None
+ if not skip_reduction: # imagenet / mobile mode
+ x, p = _reduction_a_cell(
+ x, p, filters // (filter_multiplier**2), block_id='stem_1')
+ x, p = _reduction_a_cell(
+ x, p, filters // filter_multiplier, block_id='stem_2')
+
+ for i in range(num_blocks):
+ x, p = _normal_a_cell(x, p, filters, block_id='%d' % (i))
+
+ x, p0 = _reduction_a_cell(
+ x, p, filters * filter_multiplier, block_id='reduce_%d' % (num_blocks))
+
+ p = p0 if not skip_reduction else p
+
+ for i in range(num_blocks):
+ x, p = _normal_a_cell(
+ x, p, filters * filter_multiplier, block_id='%d' % (num_blocks + i + 1))
+
+ x, p0 = _reduction_a_cell(
+ x,
+ p,
+ filters * filter_multiplier**2,
+ block_id='reduce_%d' % (2 * num_blocks))
+
+ p = p0 if not skip_reduction else p
+
+ for i in range(num_blocks):
+ x, p = _normal_a_cell(
+ x,
+ p,
+ filters * filter_multiplier**2,
+ block_id='%d' % (2 * num_blocks + i + 1))
+
+ x = Activation('relu')(x)
+
+ if include_top:
+ x = GlobalAveragePooling2D()(x)
+ x = Dense(classes, activation='softmax', name='predictions')(x)
+ else:
+ if pooling == 'avg':
+ x = GlobalAveragePooling2D()(x)
+ elif pooling == 'max':
+ x = GlobalMaxPooling2D()(x)
+
+ # Ensure that the model takes into account
+ # any potential predecessors of `input_tensor`.
+ if input_tensor is not None:
+ inputs = get_source_inputs(input_tensor)
+ else:
+ inputs = img_input
+
+ model = Model(inputs, x, name='NASNet')
+
+ # load weights
+ if weights == 'imagenet':
+ if default_size == 224: # mobile version
+ if include_top:
+ weight_path = NASNET_MOBILE_WEIGHT_PATH
+ model_name = 'nasnet_mobile.h5'
+ else:
+ weight_path = NASNET_MOBILE_WEIGHT_PATH_NO_TOP
+ model_name = 'nasnet_mobile_no_top.h5'
+
+ weights_file = get_file(model_name, weight_path, cache_subdir='models')
+ model.load_weights(weights_file)
+
+ elif default_size == 331: # large version
+ if include_top:
+ weight_path = NASNET_LARGE_WEIGHT_PATH
+ model_name = 'nasnet_large.h5'
+ else:
+ weight_path = NASNET_LARGE_WEIGHT_PATH_NO_TOP
+ model_name = 'nasnet_large_no_top.h5'
+
+ weights_file = get_file(model_name, weight_path, cache_subdir='models')
+ model.load_weights(weights_file)
+ else:
+ raise ValueError('ImageNet weights can only be loaded with NASNetLarge'
+ ' or NASNetMobile')
+ elif weights is not None:
+ model.load_weights(weights)
+
+ if old_data_format:
+ K.set_image_data_format(old_data_format)
+
+ return model
+
+
+def NASNetLarge(input_shape=None,
+ include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ pooling=None,
+ classes=1000):
+ """Instantiates a NASNet model in ImageNet mode.
+
+ Note that only TensorFlow is supported for now,
+ therefore it only works with the data format
+ `image_data_format='channels_last'` in your Keras config
+ at `~/.keras/keras.json`.
+
+ Arguments:
+ input_shape: Optional shape tuple, only to be specified
+ if `include_top` is False (otherwise the input shape
+ has to be `(331, 331, 3)` for NASNetLarge.
+ It should have exactly 3 inputs channels,
+ and width and height should be no smaller than 32.
+ E.g. `(224, 224, 3)` would be one valid value.
+ include_top: Whether to include the fully-connected
+ layer at the top of the network.
+ weights: `None` (random initialization) or
+ `imagenet` (ImageNet weights)
+ input_tensor: Optional Keras tensor (i.e. output of
+ `layers.Input()`)
+ to use as image input for the model.
+ pooling: Optional pooling mode for feature extraction
+ when `include_top` is `False`.
+ - `None` means that the output of the model
+ will be the 4D tensor output of the
+ last convolutional layer.
+ - `avg` means that global average pooling
+ will be applied to the output of the
+ last convolutional layer, and thus
+ the output of the model will be a
+ 2D tensor.
+ - `max` means that global max pooling will
+ be applied.
+ classes: Optional number of classes to classify images
+ into, only to be specified if `include_top` is True, and
+ if no `weights` argument is specified.
+
+ Returns:
+ A Keras model instance.
+
+ Raises:
+ ValueError: in case of invalid argument for `weights`,
+ or invalid input shape.
+ RuntimeError: If attempting to run this model with a
+ backend that does not support separable convolutions.
+ """
+ return NASNet(
+ input_shape,
+ penultimate_filters=4032,
+ num_blocks=6,
+ stem_block_filters=96,
+ skip_reduction=False,
+ filter_multiplier=2,
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ pooling=pooling,
+ classes=classes,
+ default_size=331)
+
+
+def NASNetMobile(input_shape=None,
+ include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ pooling=None,
+ classes=1000):
+ """Instantiates a Mobile NASNet model in ImageNet mode.
+
+ Note that only TensorFlow is supported for now,
+ therefore it only works with the data format
+ `image_data_format='channels_last'` in your Keras config
+ at `~/.keras/keras.json`.
+
+ Arguments:
+ input_shape: Optional shape tuple, only to be specified
+ if `include_top` is False (otherwise the input shape
+ has to be `(224, 224, 3)` for NASNetMobile
+ It should have exactly 3 inputs channels,
+ and width and height should be no smaller than 32.
+ E.g. `(224, 224, 3)` would be one valid value.
+ include_top: Whether to include the fully-connected
+ layer at the top of the network.
+ weights: `None` (random initialization) or
+ `imagenet` (ImageNet weights)
+ input_tensor: Optional Keras tensor (i.e. output of
+ `layers.Input()`)
+ to use as image input for the model.
+ pooling: Optional pooling mode for feature extraction
+ when `include_top` is `False`.
+ - `None` means that the output of the model
+ will be the 4D tensor output of the
+ last convolutional layer.
+ - `avg` means that global average pooling
+ will be applied to the output of the
+ last convolutional layer, and thus
+ the output of the model will be a
+ 2D tensor.
+ - `max` means that global max pooling will
+ be applied.
+ classes: Optional number of classes to classify images
+ into, only to be specified if `include_top` is True, and
+ if no `weights` argument is specified.
+
+ Returns:
+ A Keras model instance.
+
+ Raises:
+ ValueError: In case of invalid argument for `weights`,
+ or invalid input shape.
+ RuntimeError: If attempting to run this model with a
+ backend that does not support separable convolutions.
+ """
+ return NASNet(
+ input_shape,
+ penultimate_filters=1056,
+ num_blocks=4,
+ stem_block_filters=32,
+ skip_reduction=False,
+ filter_multiplier=2,
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ pooling=pooling,
+ classes=classes,
+ default_size=224)
+
+
+def _separable_conv_block(ip,
+ filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ block_id=None):
+ """Adds 2 blocks of [relu-separable conv-batchnorm].
+
+ Arguments:
+ ip: Input tensor
+ filters: Number of output filters per layer
+ kernel_size: Kernel size of separable convolutions
+ strides: Strided convolution for downsampling
+ block_id: String block_id
+
+ Returns:
+ A Keras tensor
+ """
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+
+ with K.name_scope('separable_conv_block_%s' % block_id):
+ x = Activation('relu')(ip)
+ x = SeparableConv2D(
+ filters,
+ kernel_size,
+ strides=strides,
+ name='separable_conv_1_%s' % block_id,
+ padding='same',
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ x)
+ x = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='separable_conv_1_bn_%s' % (block_id))(
+ x)
+ x = Activation('relu')(x)
+ x = SeparableConv2D(
+ filters,
+ kernel_size,
+ name='separable_conv_2_%s' % block_id,
+ padding='same',
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ x)
+ x = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='separable_conv_2_bn_%s' % (block_id))(
+ x)
+ return x
+
+
+def _adjust_block(p, ip, filters, block_id=None):
+ """Adjusts the input `previous path` to match the shape of the `input`.
+
+ Used in situations where the output number of filters needs to be changed.
+
+ Arguments:
+ p: Input tensor which needs to be modified
+ ip: Input tensor whose shape needs to be matched
+ filters: Number of output filters to be matched
+ block_id: String block_id
+
+ Returns:
+ Adjusted Keras tensor
+ """
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+ img_dim = 2 if K.image_data_format() == 'channels_first' else -2
+
+ ip_shape = K.int_shape(ip)
+
+ if p is not None:
+ p_shape = K.int_shape(p)
+
+ with K.name_scope('adjust_block'):
+ if p is None:
+ p = ip
+
+ elif p_shape[img_dim] != ip_shape[img_dim]:
+ with K.name_scope('adjust_reduction_block_%s' % block_id):
+ p = Activation('relu', name='adjust_relu_1_%s' % block_id)(p)
+
+ p1 = AveragePooling2D(
+ (1, 1),
+ strides=(2, 2),
+ padding='valid',
+ name='adjust_avg_pool_1_%s' % block_id)(
+ p)
+ p1 = Conv2D(
+ filters // 2, (1, 1),
+ padding='same',
+ use_bias=False,
+ name='adjust_conv_1_%s' % block_id,
+ kernel_initializer='he_normal')(
+ p1)
+
+ p2 = ZeroPadding2D(padding=((0, 1), (0, 1)))(p)
+ p2 = Cropping2D(cropping=((1, 0), (1, 0)))(p2)
+ p2 = AveragePooling2D(
+ (1, 1),
+ strides=(2, 2),
+ padding='valid',
+ name='adjust_avg_pool_2_%s' % block_id)(
+ p2)
+ p2 = Conv2D(
+ filters // 2, (1, 1),
+ padding='same',
+ use_bias=False,
+ name='adjust_conv_2_%s' % block_id,
+ kernel_initializer='he_normal')(
+ p2)
+
+ p = concatenate([p1, p2], axis=channel_dim)
+ p = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='adjust_bn_%s' % block_id)(
+ p)
+
+ elif p_shape[channel_dim] != filters:
+ with K.name_scope('adjust_projection_block_%s' % block_id):
+ p = Activation('relu')(p)
+ p = Conv2D(
+ filters, (1, 1),
+ strides=(1, 1),
+ padding='same',
+ name='adjust_conv_projection_%s' % block_id,
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ p)
+ p = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='adjust_bn_%s' % block_id)(
+ p)
+ return p
+
+
+def _normal_a_cell(ip, p, filters, block_id=None):
+ """Adds a Normal cell for NASNet-A (Fig. 4 in the paper).
+
+ Arguments:
+ ip: Input tensor `x`
+ p: Input tensor `p`
+ filters: Number of output filters
+ block_id: String block_id
+
+ Returns:
+ A Keras tensor
+ """
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+
+ with K.name_scope('normal_A_block_%s' % block_id):
+ p = _adjust_block(p, ip, filters, block_id)
+
+ h = Activation('relu')(ip)
+ h = Conv2D(
+ filters, (1, 1),
+ strides=(1, 1),
+ padding='same',
+ name='normal_conv_1_%s' % block_id,
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ h)
+ h = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='normal_bn_1_%s' % block_id)(
+ h)
+
+ with K.name_scope('block_1'):
+ x1_1 = _separable_conv_block(
+ h, filters, kernel_size=(5, 5), block_id='normal_left1_%s' % block_id)
+ x1_2 = _separable_conv_block(
+ p, filters, block_id='normal_right1_%s' % block_id)
+ x1 = add([x1_1, x1_2], name='normal_add_1_%s' % block_id)
+
+ with K.name_scope('block_2'):
+ x2_1 = _separable_conv_block(
+ p, filters, (5, 5), block_id='normal_left2_%s' % block_id)
+ x2_2 = _separable_conv_block(
+ p, filters, (3, 3), block_id='normal_right2_%s' % block_id)
+ x2 = add([x2_1, x2_2], name='normal_add_2_%s' % block_id)
+
+ with K.name_scope('block_3'):
+ x3 = AveragePooling2D(
+ (3, 3),
+ strides=(1, 1),
+ padding='same',
+ name='normal_left3_%s' % (block_id))(
+ h)
+ x3 = add([x3, p], name='normal_add_3_%s' % block_id)
+
+ with K.name_scope('block_4'):
+ x4_1 = AveragePooling2D(
+ (3, 3),
+ strides=(1, 1),
+ padding='same',
+ name='normal_left4_%s' % (block_id))(
+ p)
+ x4_2 = AveragePooling2D(
+ (3, 3),
+ strides=(1, 1),
+ padding='same',
+ name='normal_right4_%s' % (block_id))(
+ p)
+ x4 = add([x4_1, x4_2], name='normal_add_4_%s' % block_id)
+
+ with K.name_scope('block_5'):
+ x5 = _separable_conv_block(
+ h, filters, block_id='normal_left5_%s' % block_id)
+ x5 = add([x5, h], name='normal_add_5_%s' % block_id)
+
+ x = concatenate(
+ [p, x1, x2, x3, x4, x5],
+ axis=channel_dim,
+ name='normal_concat_%s' % block_id)
+ return x, ip
+
+
+def _reduction_a_cell(ip, p, filters, block_id=None):
+ """Adds a Reduction cell for NASNet-A (Fig. 4 in the paper).
+
+ Arguments:
+ ip: Input tensor `x`
+ p: Input tensor `p`
+ filters: Number of output filters
+ block_id: String block_id
+
+ Returns:
+ A Keras tensor
+ """
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+
+ with K.name_scope('reduction_A_block_%s' % block_id):
+ p = _adjust_block(p, ip, filters, block_id)
+
+ h = Activation('relu')(ip)
+ h = Conv2D(
+ filters, (1, 1),
+ strides=(1, 1),
+ padding='same',
+ name='reduction_conv_1_%s' % block_id,
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ h)
+ h = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='reduction_bn_1_%s' % block_id)(
+ h)
+
+ with K.name_scope('block_1'):
+ x1_1 = _separable_conv_block(
+ h,
+ filters, (5, 5),
+ strides=(2, 2),
+ block_id='reduction_left1_%s' % block_id)
+ x1_2 = _separable_conv_block(
+ p,
+ filters, (7, 7),
+ strides=(2, 2),
+ block_id='reduction_1_%s' % block_id)
+ x1 = add([x1_1, x1_2], name='reduction_add_1_%s' % block_id)
+
+ with K.name_scope('block_2'):
+ x2_1 = MaxPooling2D(
+ (3, 3),
+ strides=(2, 2),
+ padding='same',
+ name='reduction_left2_%s' % block_id)(
+ h)
+ x2_2 = _separable_conv_block(
+ p,
+ filters, (7, 7),
+ strides=(2, 2),
+ block_id='reduction_right2_%s' % block_id)
+ x2 = add([x2_1, x2_2], name='reduction_add_2_%s' % block_id)
+
+ with K.name_scope('block_3'):
+ x3_1 = AveragePooling2D(
+ (3, 3),
+ strides=(2, 2),
+ padding='same',
+ name='reduction_left3_%s' % block_id)(
+ h)
+ x3_2 = _separable_conv_block(
+ p,
+ filters, (5, 5),
+ strides=(2, 2),
+ block_id='reduction_right3_%s' % block_id)
+ x3 = add([x3_1, x3_2], name='reduction_add3_%s' % block_id)
+
+ with K.name_scope('block_4'):
+ x4 = AveragePooling2D(
+ (3, 3),
+ strides=(1, 1),
+ padding='same',
+ name='reduction_left4_%s' % block_id)(
+ x1)
+ x4 = add([x2, x4])
+
+ with K.name_scope('block_5'):
+ x5_1 = _separable_conv_block(
+ x1, filters, (3, 3), block_id='reduction_left4_%s' % block_id)
+ x5_2 = MaxPooling2D(
+ (3, 3),
+ strides=(2, 2),
+ padding='same',
+ name='reduction_right5_%s' % block_id)(
+ h)
+ x5 = add([x5_1, x5_2], name='reduction_add4_%s' % block_id)
+
+ x = concatenate(
+ [x2, x3, x4, x5],
+ axis=channel_dim,
+ name='reduction_concat_%s' % block_id)
+ return x, ip
diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py b/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py
new file mode 100644
index 0000000000..aa1dec670c
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Nasnet application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.platform import test
+
+
+class NASNetMobileTest(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.NASNetMobile(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.NASNetMobile(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 1056))
+
+ def test_with_pooling(self):
+ model = keras.applications.NASNetMobile(weights=None,
+ include_top=False,
+ pooling='avg')
+ self.assertEqual(model.output_shape, (None, 1056))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.NASNetMobile(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.NASNetMobile(weights='imagenet',
+ classes=2000)
+
+
+class NASNetLargeTest(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.NASNetLarge(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.NASNetLarge(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 4032))
+
+ def test_with_pooling(self):
+ model = keras.applications.NASNetLarge(weights=None,
+ include_top=False,
+ pooling='avg')
+ self.assertEqual(model.output_shape, (None, 4032))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.NASNetLarge(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.NASNetLarge(weights='imagenet',
+ classes=2000)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
index 8ab46693aa..5705b3481a 100644
--- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""ResNet50 model for Keras.
# Reference:
@@ -31,8 +32,8 @@ import os
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
@@ -45,7 +46,9 @@ from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils import layer_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'
@@ -78,7 +81,8 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
x = Activation('relu')(x)
x = Conv2D(
- filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x)
+ filters2, kernel_size, padding='same', name=conv_name_base + '2b')(
+ x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = Activation('relu')(x)
@@ -92,7 +96,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
2)):
- """conv_block is the block that has a conv layer at shortcut.
+ """A block that has a conv layer at shortcut.
Arguments:
input_tensor: input tensor
@@ -100,14 +104,14 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
- strides: Tuple of integers.
+ strides: Strides for the first conv layer in the block.
Returns:
Output tensor for the block.
- Note that from stage 3, the first conv layer at main path is with
- strides=(2,2)
- And the shortcut should have strides=(2,2) as well
+ Note that from stage 3,
+ the first conv layer at main path is with strides=(2, 2)
+ And the shortcut should have strides=(2, 2) as well
"""
filters1, filters2, filters3 = filters
if K.image_data_format() == 'channels_last':
@@ -118,13 +122,14 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = Conv2D(
- filters1, (1, 1), strides=strides,
- name=conv_name_base + '2a')(input_tensor)
+ filters1, (1, 1), strides=strides, name=conv_name_base + '2a')(
+ input_tensor)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
x = Activation('relu')(x)
x = Conv2D(
- filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x)
+ filters2, kernel_size, padding='same', name=conv_name_base + '2b')(
+ x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = Activation('relu')(x)
@@ -132,8 +137,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
shortcut = Conv2D(
- filters3, (1, 1), strides=strides,
- name=conv_name_base + '1')(input_tensor)
+ filters3, (1, 1), strides=strides, name=conv_name_base + '1')(
+ input_tensor)
shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
x = layers.add([x, shortcut])
@@ -152,7 +157,7 @@ def ResNet50(include_top=True,
Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
- `image_data_format="channels_last"` in your Keras config
+ `image_data_format='channels_last'` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
@@ -164,15 +169,15 @@ def ResNet50(include_top=True,
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
or `(3, 224, 224)` (with `channels_first` data format).
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 197.
E.g. `(200, 200, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
@@ -219,15 +224,18 @@ def ResNet50(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
-
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
if K.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
- x = Conv2D(64, (7, 7),
- strides=(2, 2), padding='same', name='conv1')(img_input)
+ x = Conv2D(
+ 64, (7, 7), strides=(2, 2), padding='same', name='conv1')(
+ img_input)
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2))(x)
@@ -289,4 +297,5 @@ def ResNet50(include_top=True,
model.load_weights(weights_path)
elif weights is not None:
model.load_weights(weights)
+
return model
diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16.py b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
index 38dbbdc809..c91c24e6fb 100644
--- a/tensorflow/python/keras/_impl/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""VGG16 model for Keras.
# Reference
@@ -29,8 +30,8 @@ import os
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Conv2D
from tensorflow.python.keras._impl.keras.layers import Dense
@@ -42,6 +43,7 @@ from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils import layer_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
@@ -59,7 +61,7 @@ def VGG16(include_top=True,
Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
- `image_data_format="channels_last"` in your Keras config
+ `image_data_format='channels_last'` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
@@ -71,8 +73,8 @@ def VGG16(include_top=True,
include_top: whether to include the 3 fully-connected
layers at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
@@ -125,48 +127,62 @@ def VGG16(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
-
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
# Block 1
x = Conv2D(
- 64, (3, 3), activation='relu', padding='same',
- name='block1_conv1')(img_input)
+ 64, (3, 3), activation='relu', padding='same', name='block1_conv1')(
+ img_input)
x = Conv2D(
- 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
+ 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
# Block 2
x = Conv2D(
- 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
+ 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(
+ x)
x = Conv2D(
- 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
+ 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
# Block 3
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
# Block 4
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
# Block 5
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
if include_top:
@@ -215,6 +231,8 @@ def VGG16(include_top=True,
dense = model.get_layer(name='fc1')
layer_utils.convert_dense_weights_data_format(dense, shape,
'channels_first')
+
elif weights is not None:
model.load_weights(weights)
+
return model
diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19.py b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
index 126c64260b..223cd79d7b 100644
--- a/tensorflow/python/keras/_impl/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""VGG19 model for Keras.
# Reference
@@ -29,8 +30,8 @@ import os
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Conv2D
from tensorflow.python.keras._impl.keras.layers import Dense
@@ -42,6 +43,7 @@ from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils import layer_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5'
@@ -59,7 +61,7 @@ def VGG19(include_top=True,
Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
- `image_data_format="channels_last"` in your Keras config
+ `image_data_format='channels_last'` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
@@ -71,15 +73,15 @@ def VGG19(include_top=True,
include_top: whether to include the 3 fully-connected
layers at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
or `(3, 224, 224)` (with `channels_first` data format).
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 48.
E.g. `(200, 200, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
@@ -125,54 +127,71 @@ def VGG19(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
-
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
# Block 1
x = Conv2D(
- 64, (3, 3), activation='relu', padding='same',
- name='block1_conv1')(img_input)
+ 64, (3, 3), activation='relu', padding='same', name='block1_conv1')(
+ img_input)
x = Conv2D(
- 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
+ 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
# Block 2
x = Conv2D(
- 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
+ 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(
+ x)
x = Conv2D(
- 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
+ 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
# Block 3
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv4')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv4')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
# Block 4
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv4')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv4')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
# Block 5
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv4')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv4')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
if include_top:
@@ -211,6 +230,8 @@ def VGG19(include_top=True,
cache_subdir='models',
file_hash='253f8cb515780f3b799900260a226db6')
model.load_weights(weights_path)
+ if K.backend() == 'theano':
+ layer_utils.convert_all_kernels_in_model(model)
if K.image_data_format() == 'channels_first':
if include_top:
@@ -219,6 +240,8 @@ def VGG19(include_top=True,
dense = model.get_layer(name='fc1')
layer_utils.convert_dense_weights_data_format(dense, shape,
'channels_first')
+
elif weights is not None:
model.load_weights(weights)
+
return model
diff --git a/tensorflow/python/keras/_impl/keras/applications/xception.py b/tensorflow/python/keras/_impl/keras/applications/xception.py
index 8219831408..0a6eb4953a 100644
--- a/tensorflow/python/keras/_impl/keras/applications/xception.py
+++ b/tensorflow/python/keras/_impl/keras/applications/xception.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""Xception V1 model for Keras.
On ImageNet, this model gets to a top-1 validation accuracy of 0.790
@@ -42,7 +43,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
@@ -74,7 +75,7 @@ def Xception(include_top=True,
on ImageNet. This model is available for TensorFlow only,
and can only be used with inputs following the TensorFlow
data format `(width, height, channels)`.
- You should set `image_data_format="channels_last"` in your Keras config
+ You should set `image_data_format='channels_last'` in your Keras config
located at ~/.keras/keras.json.
Note that the default input image size for this model is 299x299.
@@ -83,14 +84,14 @@ def Xception(include_top=True,
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(299, 299, 3)`.
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 71.
E.g. `(150, 150, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
@@ -155,11 +156,14 @@ def Xception(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
x = Conv2D(
- 32, (3, 3), strides=(2, 2), use_bias=False,
- name='block1_conv1')(img_input)
+ 32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(
+ img_input)
x = BatchNormalization(name='block1_conv1_bn')(x)
x = Activation('relu', name='block1_conv1_act')(x)
x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
@@ -167,53 +171,65 @@ def Xception(include_top=True,
x = Activation('relu', name='block1_conv2_act')(x)
residual = Conv2D(
- 128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
+ 128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(
+ x)
residual = BatchNormalization()(residual)
x = SeparableConv2D(
- 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
+ 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(
+ x)
x = BatchNormalization(name='block2_sepconv1_bn')(x)
x = Activation('relu', name='block2_sepconv2_act')(x)
x = SeparableConv2D(
- 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
+ 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(
+ x)
x = BatchNormalization(name='block2_sepconv2_bn')(x)
x = MaxPooling2D(
- (3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
+ (3, 3), strides=(2, 2), padding='same', name='block2_pool')(
+ x)
x = layers.add([x, residual])
residual = Conv2D(
- 256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
+ 256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(
+ x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block3_sepconv1_act')(x)
x = SeparableConv2D(
- 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
+ 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(
+ x)
x = BatchNormalization(name='block3_sepconv1_bn')(x)
x = Activation('relu', name='block3_sepconv2_act')(x)
x = SeparableConv2D(
- 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
+ 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(
+ x)
x = BatchNormalization(name='block3_sepconv2_bn')(x)
x = MaxPooling2D(
- (3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
+ (3, 3), strides=(2, 2), padding='same', name='block3_pool')(
+ x)
x = layers.add([x, residual])
residual = Conv2D(
- 728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
+ 728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(
+ x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block4_sepconv1_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
+ 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(
+ x)
x = BatchNormalization(name='block4_sepconv1_bn')(x)
x = Activation('relu', name='block4_sepconv2_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
+ 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(
+ x)
x = BatchNormalization(name='block4_sepconv2_bn')(x)
x = MaxPooling2D(
- (3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
+ (3, 3), strides=(2, 2), padding='same', name='block4_pool')(
+ x)
x = layers.add([x, residual])
for i in range(8):
@@ -222,46 +238,52 @@ def Xception(include_top=True,
x = Activation('relu', name=prefix + '_sepconv1_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False,
- name=prefix + '_sepconv1')(x)
+ 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(
+ x)
x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
x = Activation('relu', name=prefix + '_sepconv2_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False,
- name=prefix + '_sepconv2')(x)
+ 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(
+ x)
x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
x = Activation('relu', name=prefix + '_sepconv3_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False,
- name=prefix + '_sepconv3')(x)
+ 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(
+ x)
x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
x = layers.add([x, residual])
residual = Conv2D(
- 1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
+ 1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(
+ x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block13_sepconv1_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
+ 728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(
+ x)
x = BatchNormalization(name='block13_sepconv1_bn')(x)
x = Activation('relu', name='block13_sepconv2_act')(x)
x = SeparableConv2D(
- 1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
+ 1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(
+ x)
x = BatchNormalization(name='block13_sepconv2_bn')(x)
x = MaxPooling2D(
- (3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
+ (3, 3), strides=(2, 2), padding='same', name='block13_pool')(
+ x)
x = layers.add([x, residual])
x = SeparableConv2D(
- 1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
+ 1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(
+ x)
x = BatchNormalization(name='block14_sepconv1_bn')(x)
x = Activation('relu', name='block14_sepconv1_act')(x)
x = SeparableConv2D(
- 2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
+ 2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(
+ x)
x = BatchNormalization(name='block14_sepconv2_bn')(x)
x = Activation('relu', name='block14_sepconv2_act')(x)
@@ -303,8 +325,6 @@ def Xception(include_top=True,
if old_data_format:
K.set_image_data_format(old_data_format)
- elif weights is not None:
- model.load_weights(weights)
return model
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 9476085bd8..460c0dc5f3 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -85,7 +85,7 @@ _MANUAL_VAR_INIT = False
_FLOATX = 'float32'
# Epsilon fuzz factor used throughout the codebase.
-_EPSILON = 10e-8
+_EPSILON = 1e-7
# Default image data format, one of "channels_last", "channels_first".
_IMAGE_DATA_FORMAT = 'channels_last'
@@ -116,7 +116,7 @@ def epsilon():
Example:
```python
>>> keras.backend.epsilon()
- 1e-08
+ 1e-07
```
"""
return _EPSILON
@@ -132,7 +132,7 @@ def set_epsilon(value):
```python
>>> from keras import backend as K
>>> K.epsilon()
- 1e-08
+ 1e-07
>>> K.set_epsilon(1e-05)
>>> K.epsilon()
1e-05
@@ -295,7 +295,8 @@ def clear_session():
ops.reset_default_graph()
reset_uids()
_SESSION = None
- phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
+ phase = array_ops.placeholder_with_default(
+ False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase
@@ -328,7 +329,8 @@ def learning_phase():
"""
graph = ops.get_default_graph()
if graph not in _GRAPH_LEARNING_PHASES:
- phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
+ phase = array_ops.placeholder_with_default(
+ False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES[graph] = phase
return _GRAPH_LEARNING_PHASES[graph]
@@ -876,6 +878,8 @@ def zeros(shape, dtype=None, name=None):
Returns:
A variable (including Keras metadata), filled with `0.0`.
+ Note that if `shape` was symbolic, we cannot return a variable,
+ and will return a dynamically-shaped tensor instead.
Example:
```python
@@ -890,12 +894,14 @@ def zeros(shape, dtype=None, name=None):
if dtype is None:
dtype = floatx()
tf_dtype = dtypes_module.as_dtype(dtype)
- return variable(
- init_ops.constant_initializer(0., dtype=tf_dtype)(shape), dtype, name)
+ v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
def ones(shape, dtype=None, name=None):
- """Instantiates an all-ones tensor variable and returns it.
+ """Instantiates an all-ones variable and returns it.
Arguments:
shape: Tuple of integers, shape of returned Keras variable.
@@ -904,6 +910,8 @@ def ones(shape, dtype=None, name=None):
Returns:
A Keras variable, filled with `1.0`.
+ Note that if `shape` was symbolic, we cannot return a variable,
+ and will return a dynamically-shaped tensor instead.
Example:
```python
@@ -918,8 +926,10 @@ def ones(shape, dtype=None, name=None):
if dtype is None:
dtype = floatx()
tf_dtype = dtypes_module.as_dtype(dtype)
- return variable(
- init_ops.constant_initializer(1., dtype=tf_dtype)(shape), dtype, name)
+ v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
def eye(size, dtype=None, name=None):
@@ -1185,7 +1195,7 @@ def moving_average_update(x, value, momentum):
An Operation to update the variable.
"""
return moving_averages.assign_moving_average(
- x, value, momentum, zero_debias=False)
+ x, value, momentum, zero_debias=True)
# LINEAR ALGEBRA
@@ -1419,7 +1429,7 @@ def max(x, axis=None, keepdims=False):
Returns:
A tensor with maximum values of `x`.
"""
- return math_ops.reduce_max(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_max(x, axis, keepdims)
def min(x, axis=None, keepdims=False):
@@ -1436,7 +1446,7 @@ def min(x, axis=None, keepdims=False):
Returns:
A tensor with miminum values of `x`.
"""
- return math_ops.reduce_min(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_min(x, axis, keepdims)
def sum(x, axis=None, keepdims=False):
@@ -1453,7 +1463,7 @@ def sum(x, axis=None, keepdims=False):
Returns:
A tensor with sum of `x`.
"""
- return math_ops.reduce_sum(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_sum(x, axis, keepdims)
def prod(x, axis=None, keepdims=False):
@@ -1470,7 +1480,7 @@ def prod(x, axis=None, keepdims=False):
Returns:
A tensor with the product of elements of `x`.
"""
- return math_ops.reduce_prod(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_prod(x, axis, keepdims)
def cumsum(x, axis=0):
@@ -1515,10 +1525,10 @@ def var(x, axis=None, keepdims=False):
"""
if x.dtype.base_dtype == dtypes_module.bool:
x = math_ops.cast(x, floatx())
- m = math_ops.reduce_mean(x, axis=axis, keep_dims=True)
+ m = math_ops.reduce_mean(x, axis, True)
devs_squared = math_ops.square(x - m)
return math_ops.reduce_mean(
- devs_squared, axis=axis, keep_dims=keepdims)
+ devs_squared, axis, keepdims)
def std(x, axis=None, keepdims=False):
@@ -1546,7 +1556,7 @@ def mean(x, axis=None, keepdims=False):
axis: A list of integer. Axes to compute the mean.
keepdims: A boolean, whether to keep the dimensions or not.
If `keepdims` is `False`, the rank of the tensor is reduced
- by 1 for each entry in `axis`. If `keep_dims` is `True`,
+ by 1 for each entry in `axis`. If `keepdims` is `True`,
the reduced dimensions are retained with length 1.
Returns:
@@ -1554,7 +1564,7 @@ def mean(x, axis=None, keepdims=False):
"""
if x.dtype.base_dtype == dtypes_module.bool:
x = math_ops.cast(x, floatx())
- return math_ops.reduce_mean(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_mean(x, axis, keepdims)
def any(x, axis=None, keepdims=False):
@@ -1569,7 +1579,7 @@ def any(x, axis=None, keepdims=False):
A uint8 tensor (0s and 1s).
"""
x = math_ops.cast(x, dtypes_module.bool)
- return math_ops.reduce_any(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_any(x, axis, keepdims)
def all(x, axis=None, keepdims=False):
@@ -1584,7 +1594,7 @@ def all(x, axis=None, keepdims=False):
A uint8 tensor (0s and 1s).
"""
x = math_ops.cast(x, dtypes_module.bool)
- return math_ops.reduce_all(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_all(x, axis, keepdims)
def argmax(x, axis=-1):
@@ -1694,7 +1704,7 @@ def logsumexp(x, axis=None, keepdims=False):
Returns:
The reduced tensor.
"""
- return math_ops.reduce_logsumexp(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_logsumexp(x, axis, keepdims)
def round(x):
@@ -1884,6 +1894,108 @@ def cos(x):
return math_ops.cos(x)
+def _regular_normalize_batch_in_training(x,
+ gamma,
+ beta,
+ reduction_axes,
+ epsilon=1e-3):
+ """Non-fused version of `normalize_batch_in_training`.
+
+ Arguments:
+ x: Input tensor or variable.
+ gamma: Tensor by which to scale the input.
+ beta: Tensor with which to center the input.
+ reduction_axes: iterable of integers,
+ axes over which to normalize.
+ epsilon: Fuzz factor.
+
+ Returns:
+ A tuple length of 3, `(normalized_tensor, mean, variance)`.
+ """
+ mean, var = nn.moments(x, reduction_axes, None, None, False)
+ normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
+ return normed, mean, var
+
+
+def _broadcast_normalize_batch_in_training(x,
+ gamma,
+ beta,
+ reduction_axes,
+ epsilon=1e-3):
+ """Non-fused, broadcast version of `normalize_batch_in_training`.
+
+ Arguments:
+ x: Input tensor or variable.
+ gamma: Tensor by which to scale the input.
+ beta: Tensor with which to center the input.
+ reduction_axes: iterable of integers,
+ axes over which to normalize.
+ epsilon: Fuzz factor.
+
+ Returns:
+ A tuple length of 3, `(normalized_tensor, mean, variance)`.
+ """
+ mean, var = nn.moments(x, reduction_axes, None, None, False)
+ target_shape = []
+ for axis in range(ndim(x)):
+ if axis in reduction_axes:
+ target_shape.append(1)
+ else:
+ target_shape.append(array_ops.shape(x)[axis])
+ target_shape = array_ops.stack(target_shape)
+
+ broadcast_mean = array_ops.reshape(mean, target_shape)
+ broadcast_var = array_ops.reshape(var, target_shape)
+ if gamma is None:
+ broadcast_gamma = None
+ else:
+ broadcast_gamma = array_ops.reshape(gamma, target_shape)
+ if beta is None:
+ broadcast_beta = None
+ else:
+ broadcast_beta = array_ops.reshape(beta, target_shape)
+
+ normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
+ broadcast_beta, broadcast_gamma, epsilon)
+ return normed, mean, var
+
+
+def _fused_normalize_batch_in_training(x,
+ gamma,
+ beta,
+ reduction_axes,
+ epsilon=1e-3):
+ """Fused version of `normalize_batch_in_training`.
+
+ Arguments:
+ x: Input tensor or variable.
+ gamma: Tensor by which to scale the input.
+ beta: Tensor with which to center the input.
+ reduction_axes: iterable of integers,
+ axes over which to normalize.
+ epsilon: Fuzz factor.
+
+ Returns:
+ A tuple length of 3, `(normalized_tensor, mean, variance)`.
+ """
+ if list(reduction_axes) == [0, 1, 2]:
+ normalization_axis = 3
+ tf_data_format = 'NHWC'
+ else:
+ normalization_axis = 1
+ tf_data_format = 'NCHW'
+
+ if gamma is None:
+ gamma = constant_op.constant(
+ 1.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]])
+ if beta is None:
+ beta = constant_op.constant(
+ 0.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]])
+
+ return nn.fused_batch_norm(
+ x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
+
+
def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
"""Computes mean and std for batch then apply batch_normalization on batch.
@@ -1898,33 +2010,19 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
Returns:
A tuple length of 3, `(normalized_tensor, mean, variance)`.
"""
- mean, var = nn.moments(
- x, reduction_axes, shift=None, name=None, keep_dims=False)
- if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
- normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
+ if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
+ if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
+ return _broadcast_normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=epsilon)
+ return _fused_normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=epsilon)
else:
- # need broadcasting
- target_shape = []
- for axis in range(ndim(x)):
- if axis in reduction_axes:
- target_shape.append(1)
- else:
- target_shape.append(array_ops.shape(x)[axis])
- target_shape = array_ops.stack(target_shape)
-
- broadcast_mean = array_ops.reshape(mean, target_shape)
- broadcast_var = array_ops.reshape(var, target_shape)
- if gamma is None:
- broadcast_gamma = None
+ if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
+ return _regular_normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=epsilon)
else:
- broadcast_gamma = array_ops.reshape(gamma, target_shape)
- if beta is None:
- broadcast_beta = None
- else:
- broadcast_beta = array_ops.reshape(beta, target_shape)
- normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
- broadcast_beta, broadcast_gamma, epsilon)
- return normed, mean, var
+ return _broadcast_normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=epsilon)
def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
@@ -2619,7 +2717,8 @@ def rnn(step_function,
go_backwards=False,
mask=None,
constants=None,
- unroll=False):
+ unroll=False,
+ input_length=None):
"""Iterates over the time dimension of a tensor.
Arguments:
@@ -2648,6 +2747,7 @@ def rnn(step_function,
constants: a list of constant values passed at each step.
unroll: whether to unroll the RNN or to use a symbolic loop
(`while_loop` or `scan` depending on backend).
+ input_length: Unused; exists for API compatibility.
Returns:
A tuple, `(last_output, outputs, new_states)`.
@@ -2665,6 +2765,7 @@ def rnn(step_function,
ValueError: if `mask` is provided (not `None`) but states is not provided
(`len(states)` == 0).
"""
+ del input_length
ndim = len(inputs.get_shape())
if ndim < 3:
raise ValueError('Input should be at least 3D.')
@@ -3016,7 +3117,7 @@ def elu(x, alpha=1.):
Arguments:
x: A tensor or variable to compute the activation function for.
- alpha: A scalar, slope of positive section.
+ alpha: A scalar, slope of negative section.
Returns:
A tensor.
@@ -3083,7 +3184,7 @@ def categorical_crossentropy(target, output, from_logits=False):
if not from_logits:
# scale preds so that the class probas of each sample sum to 1
output /= math_ops.reduce_sum(
- output, axis=len(output.get_shape()) - 1, keep_dims=True)
+ output, len(output.get_shape()) - 1, True)
# manual computation of crossentropy
epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
@@ -3248,6 +3349,25 @@ def in_top_k(predictions, targets, k):
# CONVOLUTIONS
+def _preprocess_conv1d_input(x, data_format):
+ """Transpose and cast the input before the conv1d.
+
+ Arguments:
+ x: input tensor.
+ data_format: string, `"channels_last"` or `"channels_first"`.
+
+ Returns:
+ A tensor.
+ """
+ tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
+ if data_format == 'channels_first':
+ if not _has_nchw_support():
+ x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC
+ else:
+ tf_data_format = 'NCHW'
+ return x, tf_data_format
+
+
def _preprocess_conv2d_input(x, data_format):
"""Transpose and cast the input before the conv2d.
@@ -3461,6 +3581,66 @@ def conv2d_transpose(x,
return x
+def separable_conv1d(x,
+ depthwise_kernel,
+ pointwise_kernel,
+ strides=1,
+ padding='valid',
+ data_format=None,
+ dilation_rate=1):
+ """1D convolution with separable filters.
+
+ Arguments:
+ x: input tensor
+ depthwise_kernel: convolution kernel for the depthwise convolution.
+ pointwise_kernel: kernel for the 1x1 convolution.
+ strides: stride integer.
+ padding: string, `"same"` or `"valid"`.
+ data_format: string, `"channels_last"` or `"channels_first"`.
+ dilation_rate: integer dilation rate.
+
+ Returns:
+ Output tensor.
+
+ Raises:
+ ValueError: if `data_format` is neither `channels_last` or
+ `channels_first`.
+ """
+ if data_format is None:
+ data_format = image_data_format()
+ if data_format not in {'channels_first', 'channels_last'}:
+ raise ValueError('Unknown data_format ' + str(data_format))
+
+ x, tf_data_format = _preprocess_conv1d_input(x, data_format)
+ padding = _preprocess_padding(padding)
+ if tf_data_format == 'NHWC':
+ spatial_start_dim = 1
+ strides = (1, 1) + strides + (1,)
+ else:
+ spatial_start_dim = 2
+ strides = (1, 1, 1) + strides
+ x = array_ops.expand_dims(x, spatial_start_dim)
+ depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
+ pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
+ dilation_rate = (1,) + dilation_rate
+
+ x = nn.separable_conv2d(
+ x,
+ depthwise_kernel,
+ pointwise_kernel,
+ strides=strides,
+ padding=padding,
+ rate=dilation_rate,
+ data_format=tf_data_format)
+
+ x = array_ops.squeeze(x, [spatial_start_dim])
+
+ if data_format == 'channels_first' and tf_data_format == 'NHWC':
+ x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
+
+ return x
+
+
def separable_conv2d(x,
depthwise_kernel,
pointwise_kernel,
@@ -3921,7 +4101,10 @@ def bias_add(x, bias, data_format=None):
elif ndim(x) == 4:
if data_format == 'channels_first':
if len(bias_shape) == 1:
- x += reshape(bias, (1, bias_shape[0], 1, 1))
+ if _has_nchw_support():
+ x = nn.bias_add(x, bias, data_format='NCHW')
+ else:
+ x += reshape(bias, (1, bias_shape[0], 1, 1))
else:
x += reshape(bias, (1, bias_shape[2]) + bias_shape[:2])
elif data_format == 'channels_last':
@@ -4113,7 +4296,7 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length):
sparse_labels = math_ops.to_int32(
ctc_label_dense_to_sparse(y_true, label_length))
- y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
+ y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
return array_ops.expand_dims(
ctc.ctc_loss(
@@ -4148,7 +4331,7 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
Tensor `(top_paths, )` that contains
the log probability of each decoded sequence.
"""
- y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
+ y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
input_length = math_ops.to_int32(input_length)
if greedy:
diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py
index e34f1b6926..27833e368d 100644
--- a/tensorflow/python/keras/_impl/keras/backend_test.py
+++ b/tensorflow/python/keras/_impl/keras/backend_test.py
@@ -954,7 +954,6 @@ class BackendNNOpsTest(test.TestCase):
x = keras.backend.variable(val)
reduction_axes = (0, 2, 3)
- # case: need broadcasting
g_val = np.random.random((3,))
b_val = np.random.random((3,))
gamma = keras.backend.variable(g_val)
@@ -965,17 +964,6 @@ class BackendNNOpsTest(test.TestCase):
self.assertEqual(mean.get_shape().as_list(), [3,])
self.assertEqual(var.get_shape().as_list(), [3,])
- # case: doesn't need broadcasting
- g_val = np.random.random((1, 3, 1, 1))
- b_val = np.random.random((1, 3, 1, 1))
- gamma = keras.backend.variable(g_val)
- beta = keras.backend.variable(b_val)
- normed, mean, var = keras.backend.normalize_batch_in_training(
- x, gamma, beta, reduction_axes, epsilon=1e-3)
- self.assertEqual(normed.get_shape().as_list(), [10, 3, 10, 10])
- self.assertEqual(mean.get_shape().as_list(), [3,])
- self.assertEqual(var.get_shape().as_list(), [3,])
-
# case: gamma=None
gamma = None
normed, mean, var = keras.backend.normalize_batch_in_training(
diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py
index 8da3b85718..f0d9e0b0f5 100644
--- a/tensorflow/python/keras/_impl/keras/callbacks.py
+++ b/tensorflow/python/keras/_impl/keras/callbacks.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras callbacks: utilities called at certain points during model training.
+# pylint: disable=g-import-not-at-top
+"""Callbacks: utilities called at certain points during model training.
"""
from __future__ import absolute_import
from __future__ import division
@@ -36,12 +37,10 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
-# pylint: disable=g-import-not-at-top
try:
import requests
except ImportError:
requests = None
-# pylint: enable=g-import-not-at-top
class CallbackList(object):
@@ -109,9 +108,9 @@ class CallbackList(object):
delta_t_median = np.median(self._delta_ts_batch_begin)
if (self._delta_t_batch > 0. and
delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1):
- logging.warning(
- 'Method on_batch_begin() is slow compared '
- 'to the batch update (%f). Check your callbacks.' % delta_t_median)
+ logging.warning('Method on_batch_begin() is slow compared '
+ 'to the batch update (%f). Check your callbacks.',
+ delta_t_median)
self._t_enter_batch = time.time()
def on_batch_end(self, batch, logs=None):
@@ -132,9 +131,9 @@ class CallbackList(object):
delta_t_median = np.median(self._delta_ts_batch_end)
if (self._delta_t_batch > 0. and
(delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
- logging.warning(
- 'Method on_batch_end() is slow compared '
- 'to the batch update (%f). Check your callbacks.' % delta_t_median)
+ logging.warning('Method on_batch_end() is slow compared '
+ 'to the batch update (%f). Check your callbacks.',
+ delta_t_median)
def on_train_begin(self, logs=None):
"""Called at the beginning of training.
@@ -246,7 +245,8 @@ class BaseLogger(Callback):
class TerminateOnNaN(Callback):
- """Callback that terminates training when a NaN loss is encountered."""
+ """Callback that terminates training when a NaN loss is encountered.
+ """
def __init__(self):
super(TerminateOnNaN, self).__init__()
@@ -396,7 +396,7 @@ class ModelCheckpoint(Callback):
if mode not in ['auto', 'min', 'max']:
logging.warning('ModelCheckpoint mode %s is unknown, '
- 'fallback to auto mode.' % mode)
+ 'fallback to auto mode.', (mode), RuntimeWarning)
mode = 'auto'
if mode == 'min':
@@ -423,11 +423,11 @@ class ModelCheckpoint(Callback):
current = logs.get(self.monitor)
if current is None:
logging.warning('Can save best model only with %s available, '
- 'skipping.' % (self.monitor))
+ 'skipping.', self.monitor, RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
- print('Epoch %05d: %s improved from %0.5f to %0.5f,'
+ print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s' % (epoch + 1, self.monitor, self.best,
current, filepath))
self.best = current
@@ -437,11 +437,11 @@ class ModelCheckpoint(Callback):
self.model.save(filepath, overwrite=True)
else:
if self.verbose > 0:
- print('Epoch %05d: %s did not improve' % (epoch + 1,
- self.monitor))
+ print('\nEpoch %05d: %s did not improve' % (epoch + 1,
+ self.monitor))
else:
if self.verbose > 0:
- print('Epoch %05d: saving model to %s' % (epoch + 1, filepath))
+ print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
@@ -486,7 +486,7 @@ class EarlyStopping(Callback):
if mode not in ['auto', 'min', 'max']:
logging.warning('EarlyStopping mode %s is unknown, '
- 'fallback to auto mode.' % mode)
+ 'fallback to auto mode.', mode, RuntimeWarning)
mode = 'auto'
if mode == 'min':
@@ -514,8 +514,8 @@ class EarlyStopping(Callback):
current = logs.get(self.monitor)
if current is None:
logging.warning('Early stopping conditioned on metric `%s` '
- 'which is not available. Available metrics are: %s' %
- (self.monitor, ','.join(list(logs.keys()))))
+ 'which is not available. Available metrics are: %s',
+ self.monitor, ','.join(list(logs.keys())), RuntimeWarning)
return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
@@ -544,8 +544,6 @@ class RemoteMonitor(Callback):
path: String; path relative to `root` to which the events will be sent.
field: String; JSON field under which the data will be stored.
headers: Dictionary; optional custom HTTP headers.
- Defaults to:
- `{'Accept': 'application/json', 'Content-Type': 'application/json'}`
"""
def __init__(self,
@@ -554,11 +552,7 @@ class RemoteMonitor(Callback):
field='data',
headers=None):
super(RemoteMonitor, self).__init__()
- if headers is None:
- headers = {
- 'Accept': 'application/json',
- 'Content-Type': 'application/json'
- }
+
self.root = root
self.path = path
self.field = field
@@ -588,11 +582,13 @@ class LearningRateScheduler(Callback):
schedule: a function that takes an epoch index as input
(integer, indexed from 0) and returns a new
learning rate as output (float).
+ verbose: int. 0: quiet, 1: update messages.
"""
- def __init__(self, schedule):
+ def __init__(self, schedule, verbose=0):
super(LearningRateScheduler, self).__init__()
self.schedule = schedule
+ self.verbose = verbose
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
@@ -602,6 +598,9 @@ class LearningRateScheduler(Callback):
raise ValueError('The output of the "schedule" function '
'should be float.')
K.set_value(self.model.optimizer.lr, lr)
+ if self.verbose > 0:
+ print('\nEpoch %05d: LearningRateScheduler reducing learning '
+ 'rate to %s.' % (epoch + 1, lr))
class TensorBoard(Callback):
@@ -842,7 +841,7 @@ class ReduceLROnPlateau(Callback):
"""
if self.mode not in ['auto', 'min', 'max']:
logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
- 'fallback to auto mode.' % (self.mode))
+ 'fallback to auto mode.', self.mode, RuntimeWarning)
self.mode = 'auto'
if (self.mode == 'min' or
(self.mode == 'auto' and 'acc' not in self.monitor)):
@@ -853,7 +852,6 @@ class ReduceLROnPlateau(Callback):
self.best = -np.Inf
self.cooldown_counter = 0
self.wait = 0
- self.lr_epsilon = self.min_lr * 1e-4
def on_train_begin(self, logs=None):
self._reset()
@@ -864,8 +862,9 @@ class ReduceLROnPlateau(Callback):
current = logs.get(self.monitor)
if current is None:
logging.warning('Reduce LR on plateau conditioned on metric `%s` '
- 'which is not available. Available metrics are: %s' %
- (self.monitor, ','.join(list(logs.keys()))))
+ 'which is not available. Available metrics are: %s',
+ self.monitor, ','.join(list(logs.keys())), RuntimeWarning)
+
else:
if self.in_cooldown():
self.cooldown_counter -= 1
@@ -877,13 +876,13 @@ class ReduceLROnPlateau(Callback):
elif not self.in_cooldown():
if self.wait >= self.patience:
old_lr = float(K.get_value(self.model.optimizer.lr))
- if old_lr > self.min_lr + self.lr_epsilon:
+ if old_lr > self.min_lr:
new_lr = old_lr * self.factor
new_lr = max(new_lr, self.min_lr)
K.set_value(self.model.optimizer.lr, new_lr)
if self.verbose > 0:
- print('\nEpoch %05d: reducing learning rate to %s.' % (epoch,
- new_lr))
+ print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
+ 'rate to %s.' % (epoch + 1, new_lr))
self.cooldown_counter = self.cooldown
self.wait = 0
self.wait += 1
@@ -899,10 +898,11 @@ class CSVLogger(Callback):
including 1D iterables such as np.ndarray.
Example:
- ```python
- csv_logger = CSVLogger('training.log')
- model.fit(X_train, Y_train, callbacks=[csv_logger])
- ```
+
+ ```python
+ csv_logger = CSVLogger('training.log')
+ model.fit(X_train, Y_train, callbacks=[csv_logger])
+ ```
Arguments:
filename: filename of the csv file, e.g. 'run/log.csv'.
@@ -942,12 +942,14 @@ class CSVLogger(Callback):
else:
return k
+ if self.keys is None:
+ self.keys = sorted(logs.keys())
+
if self.model.stop_training:
# We set NA so that csv parsers do not fail for this last epoch.
logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])
if not self.writer:
- self.keys = sorted(logs.keys())
class CustomDialect(csv.excel):
delimiter = self.sep
@@ -993,32 +995,32 @@ class LambdaCallback(Callback):
Example:
- ```python
- # Print the batch number at the beginning of every batch.
- batch_print_callback = LambdaCallback(
- on_batch_begin=lambda batch,logs: print(batch))
-
- # Stream the epoch loss to a file in JSON format. The file content
- # is not well-formed JSON but rather has a JSON object per line.
- import json
- json_log = open('loss_log.json', mode='wt', buffering=1)
- json_logging_callback = LambdaCallback(
- on_epoch_end=lambda epoch, logs: json_log.write(
- json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
- on_train_end=lambda logs: json_log.close()
- )
-
- # Terminate some processes after having finished model training.
- processes = ...
- cleanup_callback = LambdaCallback(
- on_train_end=lambda logs: [
- p.terminate() for p in processes if p.is_alive()])
-
- model.fit(...,
- callbacks=[batch_print_callback,
- json_logging_callback,
- cleanup_callback])
- ```
+ ```python
+ # Print the batch number at the beginning of every batch.
+ batch_print_callback = LambdaCallback(
+ on_batch_begin=lambda batch,logs: print(batch))
+
+ # Stream the epoch loss to a file in JSON format. The file content
+ # is not well-formed JSON but rather has a JSON object per line.
+ import json
+ json_log = open('loss_log.json', mode='wt', buffering=1)
+ json_logging_callback = LambdaCallback(
+ on_epoch_end=lambda epoch, logs: json_log.write(
+ json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
+ on_train_end=lambda logs: json_log.close()
+ )
+
+ # Terminate some processes after having finished model training.
+ processes = ...
+ cleanup_callback = LambdaCallback(
+ on_train_end=lambda logs: [
+ p.terminate() for p in processes if p.is_alive()])
+
+ model.fit(...,
+ callbacks=[batch_print_callback,
+ json_logging_callback,
+ cleanup_callback])
+ ```
"""
def __init__(self,
diff --git a/tensorflow/python/keras/_impl/keras/constraints.py b/tensorflow/python/keras/_impl/keras/constraints.py
index e58e3b0377..4b051c93f3 100644
--- a/tensorflow/python/keras/_impl/keras/constraints.py
+++ b/tensorflow/python/keras/_impl/keras/constraints.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Constraints: functions that impose constraints on weights values.
+# pylint: disable=invalid-name
+"""Constraints: functions that impose constraints on weight values.
"""
from __future__ import absolute_import
from __future__ import division
@@ -54,10 +55,6 @@ class MaxNorm(Constraint):
to constrain the weights of each filter tensor of size
`(rows, cols, input_depth)`.
- References:
- - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
- Srivastava, Hinton, et al.
- 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
"""
def __init__(self, max_value=2, axis=0):
@@ -79,7 +76,7 @@ class NonNeg(Constraint):
"""
def __call__(self, w):
- w *= K.cast(w >= 0., K.floatx())
+ w *= K.cast(K.greater_equal(w, 0.), K.floatx())
return w
@@ -132,7 +129,7 @@ class MinMaxNorm(Constraint):
has shape `(input_dim, output_dim)`,
set `axis` to `0` to constrain each weight vector
of length `(input_dim,)`.
- In a `Conv2D` layer with `dim_ordering="channels_last"`,
+ In a `Conv2D` layer with `data_format="channels_last"`,
the weight tensor has shape
`(rows, cols, input_depth, output_depth)`,
set `axis` to `[0, 1, 2]`
@@ -148,8 +145,9 @@ class MinMaxNorm(Constraint):
def __call__(self, w):
norms = K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True))
- desired = (self.rate * K.clip(norms, self.min_value, self.max_value) +
- (1 - self.rate) * norms)
+ desired = (
+ self.rate * K.clip(norms, self.min_value, self.max_value) +
+ (1 - self.rate) * norms)
w *= (desired / (K.epsilon() + norms))
return w
@@ -164,13 +162,15 @@ class MinMaxNorm(Constraint):
# Aliases.
-# pylint: disable=invalid-name
max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
-# pylint: enable=invalid-name
+# Legacy aliases.
+maxnorm = max_norm
+nonneg = non_neg
+unitnorm = unit_norm
def serialize(constraint):
diff --git a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
index 5d5d2c4f75..cfd7df61d5 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
@@ -23,25 +23,25 @@ import numpy as np
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
-def load_data(path='boston_housing.npz', seed=113, test_split=0.2):
+def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
"""Loads the Boston Housing dataset.
Arguments:
path: path where to cache the dataset locally
(relative to ~/.keras/datasets).
+ test_split: fraction of the data to reserve as test set.
seed: Random seed for shuffling the data
before computing the test split.
- test_split: fraction of the data to reserve as test set.
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
"""
assert 0 <= test_split < 1
- fh = 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5'
path = get_file(
path,
origin='https://s3.amazonaws.com/keras-datasets/boston_housing.npz',
- file_hash=fh)
+ file_hash=
+ 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5')
f = np.load(path)
x = f['x']
y = f['y']
diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar.py b/tensorflow/python/keras/_impl/keras/datasets/cifar.py
index 564709c0ee..7ada3340a5 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/cifar.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/cifar.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Utilities used by the CIFAR10 and CIFAR100 datasets.
+"""Utilities common to CIFAR10 and CIFAR100 datasets.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py
index 7905da66c1..fb9d98d42c 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""CIFAR10 small image classification dataset.
+"""CIFAR10 small images classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py
index b69c0724c5..95aace599a 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""CIFAR100 small image classification dataset.
+"""CIFAR100 small images classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
@@ -40,7 +40,7 @@ def load_data(label_mode='fine'):
ValueError: in case of invalid `label_mode`.
"""
if label_mode not in ['fine', 'coarse']:
- raise ValueError('label_mode must be one of "fine" "coarse".')
+ raise ValueError('`label_mode` must be one of `"fine"`, `"coarse"`.')
dirname = 'cifar-100-python'
origin = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
diff --git a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py
index 17be684e4f..b9ae41a0d4 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py
@@ -20,7 +20,9 @@ from __future__ import print_function
import gzip
import os
+
import numpy as np
+
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
@@ -38,9 +40,8 @@ def load_data():
]
paths = []
- for given_file in files:
- paths.append(
- get_file(given_file, origin=base + given_file, cache_subdir=dirname))
+ for fname in files:
+ paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
with gzip.open(paths[0], 'rb') as lbpath:
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
diff --git a/tensorflow/python/keras/_impl/keras/datasets/imdb.py b/tensorflow/python/keras/_impl/keras/datasets/imdb.py
index 7d55ebc8e4..880c9c821b 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/imdb.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/imdb.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""IMDB movie review sentiment classification dataset.
+"""IMDB sentiment classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
@@ -21,9 +21,10 @@ from __future__ import print_function
import json
import numpy as np
-from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import _remove_long_seq
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
def load_data(path='imdb.npz',
@@ -33,7 +34,8 @@ def load_data(path='imdb.npz',
seed=113,
start_char=1,
oov_char=2,
- index_from=3):
+ index_from=3,
+ **kwargs):
"""Loads the IMDB dataset.
Arguments:
@@ -50,6 +52,7 @@ def load_data(path='imdb.npz',
oov_char: words that were cut out because of the `num_words`
or `skip_top` limit will be replaced with this character.
index_from: index actual words with this index and higher.
+ **kwargs: Used for backwards compatibility.
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -64,14 +67,21 @@ def load_data(path='imdb.npz',
Words that were not seen in the training set but are in the test set
have simply been skipped.
"""
+ # Legacy support
+ if 'nb_words' in kwargs:
+ logging.warning('The `nb_words` argument in `load_data` '
+ 'has been renamed `num_words`.')
+ num_words = kwargs.pop('nb_words')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
path = get_file(
path,
origin='https://s3.amazonaws.com/text-datasets/imdb.npz',
file_hash='599dadb1135973df5b59232a0e9a887c')
- f = np.load(path)
- x_train, labels_train = f['x_train'], f['y_train']
- x_test, labels_test = f['x_test'], f['y_test']
- f.close()
+ with np.load(path) as f:
+ x_train, labels_train = f['x_train'], f['y_train']
+ x_test, labels_test = f['x_test'], f['y_test']
np.random.seed(seed)
indices = np.arange(len(x_train))
@@ -93,14 +103,7 @@ def load_data(path='imdb.npz',
xs = [[w + index_from for w in x] for x in xs]
if maxlen:
- new_xs = []
- new_labels = []
- for x, y in zip(xs, labels):
- if len(x) < maxlen:
- new_xs.append(x)
- new_labels.append(y)
- xs = new_xs
- labels = new_labels
+ xs, labels = _remove_long_seq(maxlen, xs, labels)
if not xs:
raise ValueError('After filtering for sequences shorter than maxlen=' +
str(maxlen) + ', no sequence was kept. '
@@ -112,23 +115,15 @@ def load_data(path='imdb.npz',
# reserve 'index_from' (=3 by default) characters:
# 0 (padding), 1 (start), 2 (OOV)
if oov_char is not None:
- xs = [[oov_char if (w >= num_words or w < skip_top) else w for w in x]
- for x in xs]
+ xs = [
+ [w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
+ ]
else:
- new_xs = []
- for x in xs:
- nx = []
- for w in x:
- if skip_top <= w < num_words:
- nx.append(w)
- new_xs.append(nx)
- xs = new_xs
-
- x_train = np.array(xs[:len(x_train)])
- y_train = np.array(labels[:len(x_train)])
+ xs = [[w for w in x if skip_top <= w < num_words] for x in xs]
- x_test = np.array(xs[len(x_train):])
- y_test = np.array(labels[len(x_train):])
+ idx = len(x_train)
+ x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
+ x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
return (x_train, y_train), (x_test, y_test)
@@ -144,7 +139,8 @@ def get_word_index(path='imdb_word_index.json'):
"""
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json')
+ origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json',
+ file_hash='bfafd718b763782e994055a2d397834f')
f = open(path)
data = json.load(f)
f.close()
diff --git a/tensorflow/python/keras/_impl/keras/datasets/mnist.py b/tensorflow/python/keras/_impl/keras/datasets/mnist.py
index e98f29537f..ec12a31dcf 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/mnist.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/mnist.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""MNIST handwritten digits classification dataset.
+"""MNIST handwritten digits dataset.
"""
from __future__ import absolute_import
from __future__ import division
@@ -38,9 +38,7 @@ def load_data(path='mnist.npz'):
origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
file_hash='8a61469f7ea1b51cbae51d4f78837e45')
f = np.load(path)
- x_train = f['x_train']
- y_train = f['y_train']
- x_test = f['x_test']
- y_test = f['y_test']
+ x_train, y_train = f['x_train'], f['y_train']
+ x_test, y_test = f['x_test'], f['y_test']
f.close()
return (x_train, y_train), (x_test, y_test)
diff --git a/tensorflow/python/keras/_impl/keras/datasets/reuters.py b/tensorflow/python/keras/_impl/keras/datasets/reuters.py
index 3fed12b59f..95cf8852a9 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/reuters.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/reuters.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Reuters newswire topic classification dataset.
+"""Reuters topic classification dataset.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -22,9 +21,10 @@ from __future__ import print_function
import json
import numpy as np
-from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import _remove_long_seq
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
def load_data(path='reuters.npz',
@@ -35,7 +35,8 @@ def load_data(path='reuters.npz',
seed=113,
start_char=1,
oov_char=2,
- index_from=3):
+ index_from=3,
+ **kwargs):
"""Loads the Reuters newswire classification dataset.
Arguments:
@@ -53,6 +54,7 @@ def load_data(path='reuters.npz',
oov_char: words that were cut out because of the `num_words`
or `skip_top` limit will be replaced with this character.
index_from: index actual words with this index and higher.
+ **kwargs: Used for backwards compatibility.
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -63,14 +65,20 @@ def load_data(path='reuters.npz',
Words that were not seen in the training set but are in the test set
have simply been skipped.
"""
+ # Legacy support
+ if 'nb_words' in kwargs:
+ logging.warning('The `nb_words` argument in `load_data` '
+ 'has been renamed `num_words`.')
+ num_words = kwargs.pop('nb_words')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
path = get_file(
path,
origin='https://s3.amazonaws.com/text-datasets/reuters.npz',
file_hash='87aedbeb0cb229e378797a632c1997b6')
- npzfile = np.load(path)
- xs = npzfile['x']
- labels = npzfile['y']
- npzfile.close()
+ with np.load(path) as f:
+ xs, labels = f['x'], f['y']
np.random.seed(seed)
indices = np.arange(len(xs))
@@ -78,22 +86,13 @@ def load_data(path='reuters.npz',
xs = xs[indices]
labels = labels[indices]
- np.random.shuffle(labels)
-
if start_char is not None:
xs = [[start_char] + [w + index_from for w in x] for x in xs]
elif index_from:
xs = [[w + index_from for w in x] for x in xs]
if maxlen:
- new_xs = []
- new_labels = []
- for x, y in zip(xs, labels):
- if len(x) < maxlen:
- new_xs.append(x)
- new_labels.append(y)
- xs = new_xs
- labels = new_labels
+ xs, labels = _remove_long_seq(maxlen, xs, labels)
if not num_words:
num_words = max([max(x) for x in xs])
@@ -102,23 +101,13 @@ def load_data(path='reuters.npz',
# reserve 'index_from' (=3 by default) characters:
# 0 (padding), 1 (start), 2 (OOV)
if oov_char is not None:
- xs = [[oov_char if (w >= num_words or w < skip_top) else w for w in x]
- for x in xs]
+ xs = [[w if skip_top <= w < num_words else oov_char for w in x] for x in xs]
else:
- new_xs = []
- for x in xs:
- nx = []
- for w in x:
- if skip_top <= w < num_words:
- nx.append(w)
- new_xs.append(nx)
- xs = new_xs
-
- x_train = np.array(xs[:int(len(xs) * (1 - test_split))])
- y_train = np.array(labels[:int(len(xs) * (1 - test_split))])
-
- x_test = np.array(xs[int(len(xs) * (1 - test_split)):])
- y_test = np.array(labels[int(len(xs) * (1 - test_split)):])
+ xs = [[w for w in x if skip_top <= w < num_words] for x in xs]
+
+ idx = int(len(xs) * (1 - test_split))
+ x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
+ x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
return (x_train, y_train), (x_test, y_test)
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py
index d6e0be8e43..64aa868f38 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology.py
@@ -27,6 +27,7 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
@@ -712,8 +713,8 @@ class Network(tf_network.GraphNetwork, Layer):
for layer in self._output_layers:
self.output_names.append(layer.name)
- self.internal_input_shapes = [K.int_shape(x) for x in self.inputs]
- self.internal_output_shapes = [K.int_shape(x) for x in self.outputs]
+ self._internal_input_shapes = [K.int_shape(x) for x in self.inputs]
+ self._internal_output_shapes = [K.int_shape(x) for x in self.outputs]
@property
def uses_learning_phase(self):
@@ -1303,18 +1304,17 @@ def preprocess_weights_for_loading(layer,
Returns:
A list of weights values (Numpy arrays).
"""
- if original_keras_version == '1':
- if layer.__class__.__name__ == 'Bidirectional':
- num_weights_per_layer = len(weights) // 2
-
- forward_weights = preprocess_weights_for_loading(
- layer.forward_layer, weights[:num_weights_per_layer],
- original_keras_version, original_backend)
- backward_weights = preprocess_weights_for_loading(
- layer.backward_layer, weights[num_weights_per_layer:],
- original_keras_version, original_backend)
- weights = forward_weights + backward_weights
+ if layer.__class__.__name__ == 'Bidirectional':
+ num_weights_per_layer = len(weights) // 2
+ forward_weights = preprocess_weights_for_loading(
+ layer.forward_layer, weights[:num_weights_per_layer],
+ original_keras_version, original_backend)
+ backward_weights = preprocess_weights_for_loading(
+ layer.backward_layer, weights[num_weights_per_layer:],
+ original_keras_version, original_backend)
+ weights = forward_weights + backward_weights
+ if original_keras_version == '1':
if layer.__class__.__name__ == 'TimeDistributed':
weights = preprocess_weights_for_loading(
layer.layer, weights, original_keras_version, original_backend)
@@ -1418,7 +1418,7 @@ def preprocess_weights_for_loading(layer,
conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
if layer.__class__.__name__ in conv_layers:
- if original_backend and K.backend() != original_backend:
+ if original_backend == 'theano':
weights[0] = conv_utils.convert_kernel(weights[0])
if layer.__class__.__name__ == 'ConvLSTM2D':
weights[1] = conv_utils.convert_kernel(weights[1])
@@ -1427,10 +1427,9 @@ def preprocess_weights_for_loading(layer,
if layer.__class__.__name__ == 'ConvLSTM2D':
weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
- # convert the weights of CuDNNLSTM so that they could be loaded into LSTM
+ # Convert the weights of CuDNNLSTM so that they could be loaded into LSTM
if layer.__class__.__name__ == 'LSTM' and len(weights) == 3:
- # determine if we're loading a CuDNNLSTM layer from the number of bias
- # weights:
+ # Determine if loading a CuDNNLSTM layer from the number of bias weights:
# CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4)
# if there's no bias weight in the file, skip this conversion
units = weights[1].shape[0]
@@ -1572,3 +1571,31 @@ def load_weights_from_hdf5_group_by_name(f, layers):
for i in range(len(weight_values)):
weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
K.batch_set_value(weight_value_tuples)
+
+
+def shape_type_conversion(fn):
+ """Decorator that handles tuple/TensorShape conversion.
+
+ Used in `compute_output_shape` and `build`.
+
+ Arguments:
+ fn: function to wrap.
+
+ Returns:
+ Wrapped function.
+ """
+
+ def wrapper(instance, input_shape):
+ if input_shape is not None:
+ if isinstance(input_shape, list):
+ input_shape = [
+ tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape]
+ else:
+ input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+ output_shape = fn(instance, input_shape)
+ if output_shape is not None:
+ if isinstance(output_shape, list):
+ return [tensor_shape.TensorShape(x) for x in output_shape]
+ return tensor_shape.TensorShape(output_shape)
+
+ return wrapper
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index debea2503e..699ae2edf0 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras training and evaluation routines.
+"""Training-related part of the Keras engine.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -35,6 +34,11 @@ from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
from tensorflow.python.platform import tf_logging as logging
+try:
+ from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
+except ImportError:
+ issparse = None
+
def _standardize_input_data(data,
names,
@@ -70,89 +74,72 @@ def _standardize_input_data(data,
return []
if data is None:
return [None for _ in range(len(names))]
+
if isinstance(data, dict):
- for key, value in data.items():
- if value.__class__.__name__ == 'DataFrame':
- data[key] = value.values
- arrays = []
- for name in names:
- if name not in data:
- raise ValueError('No data provided for "' + name +
- '". Need data for each key in: ' + str(names))
- arrays.append(data[name])
+ try:
+ data = [
+ data[x].values
+ if data[x].__class__.__name__ == 'DataFrame' else data[x]
+ for x in names
+ ]
+ data = [np.expand_dims(x, 1) if x.ndim == 1 else x for x in data]
+ except KeyError as e:
+ raise ValueError('No data provided for "' + e.args[0] + '". Need data '
+ 'for each key in: ' + str(names))
elif isinstance(data, list):
- for key, value in enumerate(data):
- if value.__class__.__name__ == 'DataFrame':
- data[key] = value.values
- if len(data) != len(names):
- if data and hasattr(data[0], 'shape'):
- raise ValueError(
- 'Error when checking model ' + exception_prefix +
- ': the list of Numpy arrays '
- 'that you are passing to your model '
- 'is not the size the model expected. '
- 'Expected to see ' + str(len(names)) + ' array(s), but instead got '
- 'the following list of ' + str(len(data)) + ' arrays: ' +
- str(data)[:200] + '...')
- else:
- if len(names) == 1:
- data = [np.asarray(data)]
- else:
- raise ValueError('Error when checking model ' + exception_prefix +
- ': you are passing a list as '
- 'input to your model, '
- 'but the model expects '
- 'a list of ' + str(len(names)) +
- ' Numpy arrays instead. '
- 'The list you passed was: ' + str(data)[:200])
- arrays = data
- elif data.__class__.__name__ == 'DataFrame':
- # test if data is a DataFrame, without pandas installed
- arrays = data.values
+ data = [
+ x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
+ ]
+ data = [
+ np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x
+ for x in data
+ ]
else:
- if not hasattr(data, 'shape'):
+ data = data.values if data.__class__.__name__ == 'DataFrame' else data
+ data = [np.expand_dims(data, 1)] if data.ndim == 1 else [data]
+
+ if len(data) != len(names):
+ if data and hasattr(data[0], 'shape'):
+ raise ValueError('Error when checking model ' + exception_prefix +
+ ': the list of Numpy arrays that you are passing to '
+ 'your model is not the size the model expected. '
+ 'Expected to see ' + str(len(names)) + ' array(s), '
+ 'but instead got the following list of ' +
+ str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
+ elif len(names) > 1:
+ raise ValueError(
+ 'Error when checking model ' + exception_prefix +
+ ': you are passing a list as input to your model, '
+ 'but the model expects a list of ' + str(len(names)) +
+ ' Numpy arrays instead. The list you passed was: ' + str(data)[:200])
+ elif len(data) == 1 and not hasattr(data[0], 'shape'):
raise TypeError('Error when checking model ' + exception_prefix +
- ': data should be a Numpy array, '
- 'or list/dict of Numpy arrays. '
- 'Found: ' + str(data)[:200] + '...')
- if len(names) > 1:
- # Case: model expects multiple inputs but only received
- # a single Numpy array.
- raise ValueError('The model expects ' + str(len(names)) + ' ' +
- exception_prefix +
- ' arrays, but only received one array. '
- 'Found: array with shape ' + str(data.shape))
- arrays = [data]
-
- # Make arrays at least 2D.
- for i in range(len(names)):
- array = arrays[i]
- if len(array.shape) == 1:
- array = np.expand_dims(array, 1)
- arrays[i] = array
+ ': data should be a Numpy array, or list/dict of '
+ 'Numpy arrays. Found: ' + str(data)[:200] + '...')
+ elif len(names) == 1:
+ data = [np.asarray(data)]
# Check shapes compatibility.
if shapes:
for i in range(len(names)):
- if shapes[i] is None:
- continue
- array = arrays[i]
- if len(array.shape) != len(shapes[i]):
- raise ValueError(
- 'Error when checking ' + exception_prefix + ': expected ' + names[i]
- + ' to have ' + str(len(shapes[i])) +
- ' dimensions, but got array with shape ' + str(array.shape))
- for j, (dim, ref_dim) in enumerate(zip(array.shape, shapes[i])):
- if not j and not check_batch_axis:
- # skip the first axis
- continue
- if ref_dim:
- if ref_dim != dim:
- raise ValueError('Error when checking ' + exception_prefix +
- ': expected ' + names[i] + ' to have shape ' +
- str(shapes[i]) + ' but got array with shape ' +
- str(array.shape))
- return arrays
+ if shapes[i] is not None:
+ data_shape = data[i].shape
+ shape = shapes[i]
+ if data[i].ndim != len(shape):
+ raise ValueError('Error when checking ' + exception_prefix +
+ ': expected ' + names[i] + ' to have ' +
+ str(len(shape)) + ' dimensions, but got array '
+ 'with shape ' + str(data_shape))
+ if not check_batch_axis:
+ data_shape = data_shape[1:]
+ shape = shape[1:]
+ for dim, ref_dim in zip(data_shape, shape):
+ if ref_dim != dim and ref_dim:
+ raise ValueError(
+ 'Error when checking ' + exception_prefix + ': expected ' +
+ names[i] + ' to have shape ' + str(shape) +
+ ' but got array with shape ' + str(data_shape))
+ return data
def _standardize_sample_or_class_weights(x_weight, output_names, weight_type):
@@ -193,10 +180,10 @@ def _standardize_sample_or_class_weights(x_weight, output_names, weight_type):
x_weights.append(x_weight.get(name))
return x_weights
else:
- raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
- 'should be either a list or a dict. '
- 'Provided `' + weight_type + '` type not understood: ' +
- str(x_weight))
+ raise TypeError(
+ 'The model has multiple outputs, so `' + weight_type + '` '
+ 'should be either a list or a dict. '
+ 'Provided `' + weight_type + '` type not understood: ' + str(x_weight))
def _standardize_class_weights(class_weight, output_names):
@@ -234,12 +221,12 @@ def _check_array_lengths(inputs, targets, weights=None):
set_w = set_of_lengths(weights)
if len(set_x) > 1:
raise ValueError('All input arrays (x) should have '
- 'the same number of samples. Got array shapes: ' + str(
- [x.shape for x in inputs]))
+ 'the same number of samples. Got array shapes: ' +
+ str([x.shape for x in inputs]))
if len(set_y) > 1:
raise ValueError('All target arrays (y) should have '
- 'the same number of samples. Got array shapes: ' + str(
- [y.shape for y in targets]))
+ 'the same number of samples. Got array shapes: ' +
+ str([y.shape for y in targets]))
if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
raise ValueError('Input arrays should have '
'the same number of samples as target arrays. '
@@ -247,8 +234,8 @@ def _check_array_lengths(inputs, targets, weights=None):
'and ' + str(list(set_y)[0]) + ' target samples.')
if len(set_w) > 1:
raise ValueError('All sample_weight arrays should have '
- 'the same number of samples. Got array shapes: ' + str(
- [w.shape for w in weights]))
+ 'the same number of samples. Got array shapes: ' +
+ str([w.shape for w in weights]))
if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
raise ValueError('Sample_weight arrays should have '
'the same number of samples as target arrays. Got ' +
@@ -528,16 +515,16 @@ def _standardize_weights(y,
if sample_weight is not None:
if len(sample_weight.shape) > len(y.shape):
- raise ValueError('Found a sample_weight with shape' +
- str(sample_weight.shape) + '.'
- 'Expected sample_weight with rank '
- 'less than or equal to ' + str(len(y.shape)))
+ raise ValueError(
+ 'Found a sample_weight with shape' + str(sample_weight.shape) + '.'
+ 'Expected sample_weight with rank '
+ 'less than or equal to ' + str(len(y.shape)))
if y.shape[:sample_weight.ndim] != sample_weight.shape:
- raise ValueError('Found a sample_weight array with shape ' +
- str(sample_weight.shape) + ' for an input with shape ' +
- str(y.shape) + '. '
- 'sample_weight cannot be broadcast.')
+ raise ValueError(
+ 'Found a sample_weight array with shape ' + str(sample_weight.shape) +
+ ' for an input with shape ' + str(y.shape) + '. '
+ 'sample_weight cannot be broadcast.')
return sample_weight
elif isinstance(class_weight, dict):
if len(y.shape) > 2:
@@ -632,7 +619,6 @@ class Model(Network):
"""
loss = loss or {}
self.optimizer = optimizers.get(optimizer)
- self.sample_weight_mode = sample_weight_mode
self.loss = loss
self.loss_weights = loss_weights
self.sample_weight_mode = sample_weight_mode
@@ -641,10 +627,10 @@ class Model(Network):
if isinstance(loss, dict):
for name in loss:
if name not in self.output_names:
- raise ValueError('Unknown entry in loss '
- 'dictionary: "' + name + '". '
- 'Only expected the following keys: ' +
- str(self.output_names))
+ raise ValueError(
+ 'Unknown entry in loss '
+ 'dictionary: "' + name + '". '
+ 'Only expected the following keys: ' + str(self.output_names))
loss_functions = []
for name in self.output_names:
if name not in loss:
@@ -657,7 +643,7 @@ class Model(Network):
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
raise ValueError('When passing a list as loss, '
- 'it should have one entry per model output. '
+ 'it should have one entry per model outputs. '
'The model has ' + str(len(self.outputs)) +
' outputs, but you passed loss=' + str(loss))
loss_functions = [losses.get(l) for l in loss]
@@ -690,20 +676,20 @@ class Model(Network):
elif isinstance(loss_weights, dict):
for name in loss_weights:
if name not in self.output_names:
- raise ValueError('Unknown entry in loss_weights '
- 'dictionary: "' + name + '". '
- 'Only expected the following keys: ' +
- str(self.output_names))
+ raise ValueError(
+ 'Unknown entry in loss_weights '
+ 'dictionary: "' + name + '". '
+ 'Only expected the following keys: ' + str(self.output_names))
loss_weights_list = []
for name in self.output_names:
loss_weights_list.append(loss_weights.get(name, 1.))
elif isinstance(loss_weights, list):
if len(loss_weights) != len(self.outputs):
- raise ValueError('When passing a list as loss_weights, '
- 'it should have one entry per model output. '
- 'The model has ' + str(len(self.outputs)) +
- ' outputs, but you passed loss_weights=' +
- str(loss_weights))
+ raise ValueError(
+ 'When passing a list as loss_weights, '
+ 'it should have one entry per model output. '
+ 'The model has ' + str(len(self.outputs)) +
+ ' outputs, but you passed loss_weights=' + str(loss_weights))
loss_weights_list = loss_weights
else:
raise TypeError('Could not interpret loss_weights argument: ' +
@@ -715,22 +701,22 @@ class Model(Network):
if target_tensors is not None:
if isinstance(target_tensors, list):
if len(target_tensors) != len(self.outputs):
- raise ValueError('When passing a list as `target_tensors`, '
- 'it should have one entry per model output. '
- 'The model has ' + str(len(self.outputs)) +
- ' outputs, but you passed target_tensors=' +
- str(target_tensors))
+ raise ValueError(
+ 'When passing a list as `target_tensors`, '
+ 'it should have one entry per model output. '
+ 'The model has ' + str(len(self.outputs)) +
+ ' outputs, but you passed target_tensors=' + str(target_tensors))
elif isinstance(target_tensors, dict):
for name in target_tensors:
if name not in self.output_names:
- raise ValueError('Unknown entry in `target_tensors` '
- 'dictionary: "' + name + '". '
- 'Only expected the following keys: ' +
- str(self.output_names))
- target_tensors_ = []
+ raise ValueError(
+ 'Unknown entry in `target_tensors` '
+ 'dictionary: "' + name + '". '
+ 'Only expected the following keys: ' + str(self.output_names))
+ tmp_target_tensors = []
for name in self.output_names:
- target_tensors_.append(target_tensors.get(name, None))
- target_tensors = target_tensors_
+ tmp_target_tensors.append(target_tensors.get(name, None))
+ target_tensors = tmp_target_tensors
else:
raise TypeError('Expected `target_tensors` to be '
'a list or dict, but got:', target_tensors)
@@ -738,7 +724,7 @@ class Model(Network):
if i in skip_target_indices:
self.targets.append(None)
else:
- shape = self.internal_output_shapes[i]
+ shape = self._internal_output_shapes[i]
name = self.output_names[i]
if target_tensors is not None:
target = target_tensors[i]
@@ -766,19 +752,19 @@ class Model(Network):
if isinstance(sample_weight_mode, dict):
for name in sample_weight_mode:
if name not in self.output_names:
- raise ValueError('Unknown entry in '
- 'sample_weight_mode dictionary: "' + name + '". '
- 'Only expected the following keys: ' +
- str(self.output_names))
+ raise ValueError(
+ 'Unknown entry in '
+ 'sample_weight_mode dictionary: "' + name + '". '
+ 'Only expected the following keys: ' + str(self.output_names))
for i, name in enumerate(self.output_names):
if i in skip_target_weighing_indices:
weight = None
sample_weight_modes.append(None)
else:
if name not in sample_weight_mode:
- raise ValueError('Output "' + name +
- '" missing from sample_weight_modes '
- 'dictionary')
+ raise ValueError(
+ 'Output "' + name + '" missing from sample_weight_modes '
+ 'dictionary')
if sample_weight_mode.get(name) == 'temporal':
weight = K.placeholder(ndim=2, name=name + '_sample_weights')
sample_weight_modes.append('temporal')
@@ -894,23 +880,36 @@ class Model(Network):
metric_name_prefix = 'weighted_' if weights is not None else ''
for metric in metrics:
- if metric == 'accuracy' or metric == 'acc':
- # custom handling of accuracy
+ if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
+ # custom handling of accuracy/crossentropy
# (because of class mode duality)
- output_shape = self.internal_output_shapes[i]
+ output_shape = self._internal_output_shapes[i]
if (output_shape[-1] == 1 or
self.loss_functions[i] == losses.binary_crossentropy):
- # case: binary accuracy
- acc_fn = metrics_module.binary_accuracy
+ # case: binary accuracy/crossentropy
+ if metric in ('accuracy', 'acc'):
+ acc_fn = metrics_module.binary_accuracy
+ elif metric in ('crossentropy', 'ce'):
+ acc_fn = metrics_module.binary_crossentropy
elif self.loss_functions[
i] == losses.sparse_categorical_crossentropy:
- # case: categorical accuracy with sparse targets
- acc_fn = metrics_module.sparse_categorical_accuracy
+ # case: categorical accuracy/crossentropy with sparse targets
+ if metric in ('accuracy', 'acc'):
+ acc_fn = metrics_module.sparse_categorical_accuracy
+ elif metric in ('crossentropy', 'ce'):
+ acc_fn = metrics_module.sparse_categorical_crossentropy
else:
- acc_fn = metrics_module.categorical_accuracy
-
+ # case: categorical accuracy/crossentropy
+ if metric in ('accuracy', 'acc'):
+ acc_fn = metrics_module.categorical_accuracy
+ elif metric in ('crossentropy', 'ce'):
+ acc_fn = metrics_module.categorical_crossentropy
+ if metric in ('accuracy', 'acc'):
+ suffix = 'acc'
+ elif metric in ('crossentropy', 'ce'):
+ suffix = 'ce'
weighted_metric_fn = _weighted_masked_objective(acc_fn)
- metric_name = metric_name_prefix + 'acc'
+ metric_name = metric_name_prefix + suffix
else:
metric_fn = metrics_module.get(metric)
weighted_metric_fn = _weighted_masked_objective(metric_fn)
@@ -949,7 +948,7 @@ class Model(Network):
"""Check trainable weights count consistency.
This will raise a warning if `trainable_weights` and
- `_collected_trainable_weights` are consistent (i.e. have the same
+ `_collected_trainable_weights` are inconsistent (i.e. have different
number of parameters).
Inconsistency will typically arise when one modifies `model.trainable`
without calling `model.compile` again.
@@ -959,9 +958,10 @@ class Model(Network):
if len(self.trainable_weights) != len(self._collected_trainable_weights):
logging.warning(
- 'Discrepancy between trainable weights and collected trainable'
- ' weights, did you set `model.trainable` without calling'
- ' `model.compile` after ?')
+ UserWarning(
+ 'Discrepancy between trainable weights and collected trainable'
+ ' weights, did you set `model.trainable` without calling'
+ ' `model.compile` after ?'))
def _make_train_function(self):
if not hasattr(self, 'train_function'):
@@ -1050,18 +1050,21 @@ class Model(Network):
processed based on the size of the first dimension of the
first input numpy array. When steps is not `None` and
`batch_size` is `None`, returns `None`.
+
+ Raises:
+ ValueError: In case of invalid arguments.
"""
if steps is not None:
num_samples = None
if batch_size is not None:
- raise ValueError('If ' + steps_name +
- ' is set, the `batch_size` must be None.')
+ raise ValueError(
+ 'If ' + steps_name + ' is set, the `batch_size` must be None.')
elif ins and hasattr(ins[0], 'shape'):
num_samples = ins[0].shape[0]
else:
- raise ValueError('Either the input data should have '
- 'a defined shape, or ' + steps_name +
- ' should be specified.')
+ raise ValueError(
+ 'Either the input data should have '
+ 'a defined shape, or ' + steps_name + ' should be specified.')
return num_samples
def _fit_loop(self,
@@ -1104,31 +1107,33 @@ class Model(Network):
steps_per_epoch: Total number of steps (batches of samples)
before declaring one epoch finished and starting the
next epoch. Ignored with the default value of `None`.
- validation_steps: Number of steps to run validation for (only if doing
- validation from data tensors). Ignored with default value of `None`.
+ validation_steps: Number of steps to run validation for
+ (only if doing validation from data tensors).
+ Ignored with the default value of `None`.
Returns:
`History` object.
Raises:
- ValueError: In case of invalid argument values.
+ ValueError: in case of invalid arguments.
"""
do_validation = False
if val_f and val_ins:
do_validation = True
- if (verbose and ins and
- hasattr(ins[0], 'shape') and hasattr(val_ins[0], 'shape')):
+ if verbose and ins and hasattr(ins[0], 'shape') and hasattr(
+ val_ins[0], 'shape'):
print('Train on %d samples, validate on %d samples' %
(ins[0].shape[0], val_ins[0].shape[0]))
if validation_steps:
- if steps_per_epoch is None:
- raise ValueError('Can only use `validation_steps` when doing step-wise '
- 'training, i.e. `steps_per_epoch` must be set.')
do_validation = True
+ if steps_per_epoch is None:
+ raise ValueError('Can only use `validation_steps` '
+ 'when doing step-wise '
+ 'training, i.e. `steps_per_epoch` '
+ 'must be set.')
num_train_samples = self._check_num_samples(
ins, batch_size, steps_per_epoch, 'steps_per_epoch')
-
if num_train_samples is not None:
index_array = np.arange(num_train_samples)
@@ -1165,6 +1170,13 @@ class Model(Network):
for cbk in callbacks:
cbk.validation_data = val_ins
+ # To prevent a slowdown, we find beforehand the arrays that need conversion.
+ feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
+ indices_for_conversion_to_dense = []
+ for i in range(len(feed)):
+ if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
+ indices_for_conversion_to_dense.append(i)
+
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
@@ -1220,6 +1232,9 @@ class Model(Network):
batch_logs['batch'] = batch_index
batch_logs['size'] = len(batch_ids)
callbacks.on_batch_begin(batch_index, batch_logs)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
+
outs = f(ins_batch)
if not isinstance(outs, list):
outs = [outs]
@@ -1268,6 +1283,13 @@ class Model(Network):
progbar = Progbar(target=steps)
else:
progbar = Progbar(target=num_samples)
+
+ indices_for_conversion_to_dense = []
+ for i in range(len(self._feed_inputs)):
+ if (issparse is not None and issparse(ins[i]) and
+ not K.is_sparse(self._feed_inputs[i])):
+ indices_for_conversion_to_dense.append(i)
+
if steps is not None:
# Step-based predictions.
# Since we do not know how many samples
@@ -1305,6 +1327,9 @@ class Model(Network):
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
+
batch_outs = f(ins_batch)
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
@@ -1341,12 +1366,19 @@ class Model(Network):
"""
num_samples = self._check_num_samples(ins, batch_size, steps, 'steps')
outs = []
-
if verbose == 1:
if steps is not None:
progbar = Progbar(target=steps)
else:
progbar = Progbar(target=num_samples)
+
+ # To prevent a slowdown, we find beforehand the arrays that need conversion.
+ feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
+ indices_for_conversion_to_dense = []
+ for i in range(len(feed)):
+ if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
+ indices_for_conversion_to_dense.append(i)
+
if steps is not None:
for step in range(steps):
batch_outs = f(ins)
@@ -1365,8 +1397,6 @@ class Model(Network):
for i in range(len(outs)):
outs[i] /= steps
else:
- if verbose == 1:
- progbar = Progbar(target=num_samples)
batches = _make_batches(num_samples, batch_size)
index_array = np.arange(num_samples)
for batch_index, (batch_start, batch_end) in enumerate(batches):
@@ -1376,6 +1406,8 @@ class Model(Network):
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
batch_outs = f(ins_batch)
if isinstance(batch_outs, list):
@@ -1484,7 +1516,8 @@ class Model(Network):
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
- validation_steps=None):
+ validation_steps=None,
+ **kwargs):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
Arguments:
@@ -1501,10 +1534,9 @@ class Model(Network):
dictionary mapping output names to Numpy arrays.
`y` can be `None` (default) if feeding from
TensorFlow data tensors.
- Can be `None` (default) if feeding from framework-native tensors.
batch_size: Integer or `None`.
Number of samples per gradient update.
- If unspecified, it will default to 32.
+ If unspecified, `batch_size` will default to 32.
epochs: Integer. Number of epochs to train the model.
An epoch is an iteration over the entire `x` and `y`
data provided.
@@ -1513,7 +1545,7 @@ class Model(Network):
The model is not trained for a number of iterations
given by `epochs`, but merely until the epoch
of index `epochs` is reached.
- verbose: 0, 1, or 2. Verbosity mode.
+ verbose: Integer. 0, 1, or 2. Verbosity mode.
0 = silent, 1 = progress bar, 2 = one line per epoch.
callbacks: List of `keras.callbacks.Callback` instances.
List of callbacks to apply during training.
@@ -1530,7 +1562,7 @@ class Model(Network):
`(x_val, y_val, val_sample_weights)` on which to evaluate
the loss and any model metrics at the end of each epoch.
The model will not be trained on this data.
- This will override `validation_split`.
+ `validation_data` will override `validation_split`.
shuffle: Boolean (whether to shuffle the training data
before each epoch) or str (for 'batch').
'batch' is a special option for dealing with the
@@ -1553,17 +1585,20 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`.
- initial_epoch: Epoch at which to start training
+ initial_epoch: Integer.
+ Epoch at which to start training
(useful for resuming a previous training run).
- steps_per_epoch: Total number of steps (batches of samples)
+ steps_per_epoch: Integer or `None`.
+ Total number of steps (batches of samples)
before declaring one epoch finished and starting the
next epoch. When training with input tensors such as
TensorFlow data tensors, the default `None` is equal to
- the number of unique samples in your dataset divided by
+ the number of samples in your dataset divided by
the batch size, or 1 if that cannot be determined.
validation_steps: Only relevant if `steps_per_epoch`
is specified. Total number of steps (batches of samples)
to validate before stopping.
+ **kwargs: Used for backwards compatibility.
Returns:
A `History` object. Its `History.history` attribute is
@@ -1572,12 +1607,21 @@ class Model(Network):
and validation metrics values (if applicable).
Raises:
+ RuntimeError: If the model was never compiled.
ValueError: In case of mismatch between the provided input data
and what the model expects.
"""
# Backwards compatibility
if batch_size is None and steps_per_epoch is None:
batch_size = 32
+ # Legacy support
+ if 'nb_epoch' in kwargs:
+ logging.warning(
+ 'The `nb_epoch` argument in `fit` '
+ 'has been renamed `epochs`.')
+ epochs = kwargs.pop('nb_epoch')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
if x is None and y is None and steps_per_epoch is None:
raise ValueError('If fitting from data tensors, '
'you should specify the `steps_per_epoch` '
@@ -1590,10 +1634,8 @@ class Model(Network):
class_weight=class_weight,
check_batch_axis=False,
batch_size=batch_size)
-
# Prepare validation data.
do_validation = False
- val_ins = []
if validation_data:
do_validation = True
if len(validation_data) == 2:
@@ -1657,8 +1699,9 @@ class Model(Network):
'val_' + n for n in out_labels
]
else:
- val_f = None
callback_metrics = copy.copy(out_labels)
+ val_f = None
+ val_ins = []
# Delegate logic to `_fit_loop`.
return self._fit_loop(
@@ -1694,14 +1737,14 @@ class Model(Network):
If input layers in the model are named, you can also pass a
dictionary mapping input names to Numpy arrays.
`x` can be `None` (default) if feeding from
- framework-native tensors (e.g. TensorFlow data tensors).
+ TensorFlow data tensors.
y: Numpy array of target (label) data
(if the model has a single output),
or list of Numpy arrays (if the model has multiple outputs).
If output layers in the model are named, you can also pass a
dictionary mapping output names to Numpy arrays.
`y` can be `None` (default) if feeding from
- framework-native tensors (e.g. TensorFlow data tensors).
+ TensorFlow data tensors.
batch_size: Integer or `None`.
Number of samples per evaluation step.
If unspecified, `batch_size` will default to 32.
@@ -1721,8 +1764,7 @@ class Model(Network):
steps: Integer or `None`.
Total number of steps (batches of samples)
before declaring the evaluation round finished.
- The default `None` is equal to the number of unique samples in
- your dataset divided by the batch size.
+ Ignored with the default value of `None`.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1731,7 +1773,7 @@ class Model(Network):
the display labels for the scalar outputs.
Raises:
- ValueError: In case of invalid arguments.
+ ValueError: in case of invalid arguments.
"""
# Backwards compatibility.
if batch_size is None and steps is None:
@@ -1890,6 +1932,9 @@ class Model(Network):
or list of scalars (if the model has multiple outputs
and/or metrics). The attribute `model.metrics_names` will give you
the display labels for the scalar outputs.
+
+ Raises:
+ ValueError: in case of invalid arguments.
"""
x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight, check_batch_axis=True)
@@ -1937,8 +1982,7 @@ class Model(Network):
workers=1,
use_multiprocessing=False,
shuffle=True,
- initial_epoch=0,
- **kwargs):
+ initial_epoch=0):
"""Fits the model on data yielded batch-by-batch by a Python generator.
The generator is run in parallel to the model, for efficiency.
@@ -1950,22 +1994,31 @@ class Model(Network):
using `use_multiprocessing=True`.
Arguments:
- generator: A generator or an instance of Sequence (keras.utils.Sequence)
- object in order to avoid duplicate data when using multiprocessing.
+ generator: A generator or an instance of `Sequence`
+ (`keras.utils.Sequence`)
+ object in order to avoid duplicate data
+ when using multiprocessing.
The output of the generator must be either
- - a tuple (inputs, targets)
- - a tuple (inputs, targets, sample_weights).
- All arrays should contain the same number of samples.
+ - a tuple `(inputs, targets)`
+ - a tuple `(inputs, targets, sample_weights)`.
+ This tuple (a single output of the generator) makes a single batch.
+ Therefore, all arrays in this tuple must have the same length (equal
+ to the size of this batch). Different batches may have different
+ sizes.
+ For example, the last batch of the epoch is commonly smaller than
+ the
+ others, if the size of the dataset is not divisible by the batch
+ size.
The generator is expected to loop over its data
indefinitely. An epoch finishes when `steps_per_epoch`
batches have been seen by the model.
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
finished and starting the next epoch. It should typically
- be equal to the number of unique samples of your dataset
+ be equal to the number of samples of your dataset
divided by the batch size.
Optional for `Sequence`: if unspecified, will use
- `len(generator)` as a number of steps.
+ the `len(generator)` as a number of steps.
epochs: Integer, total number of iterations on the data.
verbose: Verbosity mode, 0, 1, or 2.
callbacks: List of callbacks to be called during training.
@@ -1977,27 +2030,28 @@ class Model(Network):
is a generator. Total number of steps (batches of samples)
to yield from `generator` before stopping.
Optional for `Sequence`: if unspecified, will use
- `len(generator)` as a number of steps.
+ the `len(validation_data)` as a number of steps.
class_weight: Dictionary mapping class indices to a weight
for the class.
- max_queue_size: Maximum size for the generator queue.
+ max_queue_size: Integer. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
workers: Integer. Maximum number of processes to spin up
when using process based threading.
If unspecified, `workers` will default to 1. If 0, will
execute the generator on the main thread.
- use_multiprocessing: If True, use process based threading.
+ use_multiprocessing: Boolean. If True, use process based threading.
+ If unspecified, `workers` will default to False.
Note that because
this implementation relies on multiprocessing,
you should not pass
non picklable arguments to the generator
as they can't be passed
easily to children processes.
- shuffle: Whether to shuffle the data at the beginning of each
- epoch. Only used with instances of `Sequence`
- (`keras.utils.Sequence`).
+ shuffle: Whether to shuffle the order of the batches at
+ the beginning of each epoch. Only used with instances
+ of `Sequence` (keras.utils.Sequence).
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
- **kwargs: support for legacy arguments.
Returns:
A `History` object.
@@ -2023,19 +2077,6 @@ class Model(Network):
ValueError: In case the generator yields
data in an invalid format.
"""
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
- if kwargs:
- raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
-
wait_time = 0.01 # in seconds
epoch = initial_epoch
@@ -2046,10 +2087,11 @@ class Model(Network):
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
- logging.warning('Using a generator with `use_multiprocessing=True`'
+ logging.warning(
+ UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
- ' class.')
+ ' class.'))
if steps_per_epoch is None:
if is_sequence:
steps_per_epoch = len(generator)
@@ -2098,26 +2140,47 @@ class Model(Network):
})
callbacks.on_train_begin()
- if do_validation and not val_gen:
- if len(validation_data) == 2:
- val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
- val_sample_weight = None
- elif len(validation_data) == 3:
- val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
- else:
- raise ValueError('`validation_data` should be a tuple '
- '`(val_x, val_y, val_sample_weight)` '
- 'or `(val_x, val_y)`. Found: ' + str(validation_data))
- val_x, val_y, val_sample_weights = self._standardize_user_data(
- val_x, val_y, val_sample_weight)
- val_data = val_x + val_y + val_sample_weights
- if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
- val_data += [0.]
- for cbk in callbacks:
- cbk.validation_data = val_data
enqueuer = None
+ val_enqueuer = None
try:
+ if do_validation:
+ if val_gen:
+ if workers > 0:
+ if isinstance(validation_data, Sequence):
+ val_enqueuer = OrderedEnqueuer(
+ validation_data, use_multiprocessing=use_multiprocessing)
+ if validation_steps is None:
+ validation_steps = len(validation_data)
+ else:
+ val_enqueuer = GeneratorEnqueuer(
+ validation_data,
+ use_multiprocessing=use_multiprocessing,
+ wait_time=wait_time)
+ val_enqueuer.start(workers=workers, max_queue_size=max_queue_size)
+ validation_generator = val_enqueuer.get()
+ else:
+ validation_generator = validation_data
+ else:
+ if len(validation_data) == 2:
+ val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
+ val_sample_weight = None
+ elif len(validation_data) == 3:
+ val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
+ else:
+ raise ValueError(
+ '`validation_data` should be a tuple '
+ '`(val_x, val_y, val_sample_weight)` '
+ 'or `(val_x, val_y)`. Found: ' + str(validation_data))
+ val_x, val_y, val_sample_weights = self._standardize_user_data(
+ val_x, val_y, val_sample_weight)
+ val_data = val_x + val_y + val_sample_weights
+ if self.uses_learning_phase and not isinstance(
+ K.learning_phase(), int):
+ val_data += [0.]
+ for cbk in callbacks:
+ cbk.validation_data = val_data
+
if workers > 0:
if is_sequence:
enqueuer = OrderedEnqueuer(
@@ -2135,6 +2198,8 @@ class Model(Network):
output_generator = generator
callback_model.stop_training = False
+ # Construct epoch logs.
+ epoch_logs = {}
while epoch < epochs:
callbacks.on_epoch_begin(epoch)
steps_done = 0
@@ -2178,8 +2243,6 @@ class Model(Network):
callbacks.on_batch_end(batch_index, batch_logs)
- # Construct epoch logs.
- epoch_logs = {}
batch_index += 1
steps_done += 1
@@ -2187,11 +2250,7 @@ class Model(Network):
if steps_done >= steps_per_epoch and do_validation:
if val_gen:
val_outs = self.evaluate_generator(
- validation_data,
- validation_steps,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing)
+ validation_generator, validation_steps, workers=0)
else:
# No need for try/except because
# data has already been validated.
@@ -2216,8 +2275,12 @@ class Model(Network):
break
finally:
- if enqueuer is not None:
- enqueuer.stop()
+ try:
+ if enqueuer is not None:
+ enqueuer.stop()
+ finally:
+ if val_enqueuer is not None:
+ val_enqueuer.stop()
callbacks.on_train_end()
return self.history
@@ -2227,8 +2290,7 @@ class Model(Network):
steps=None,
max_queue_size=10,
workers=1,
- use_multiprocessing=False,
- **kwargs):
+ use_multiprocessing=False):
"""Evaluates the model on a data generator.
The generator should return the same kind of data
@@ -2256,7 +2318,6 @@ class Model(Network):
non picklable arguments to the generator
as they can't be passed
easily to children processes.
- **kwargs: support for legacy arguments.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -2265,22 +2326,12 @@ class Model(Network):
the display labels for the scalar outputs.
Raises:
+ ValueError: in case of invalid arguments.
+
+ Raises:
ValueError: In case the generator yields
data in an invalid format.
"""
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
- if kwargs:
- raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
-
self._make_test_function()
steps_done = 0
@@ -2289,10 +2340,11 @@ class Model(Network):
batch_sizes = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
- logging.warning('Using a generator with `use_multiprocessing=True`'
+ logging.warning(
+ UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
- ' class.')
+ ' class.'))
if steps is None:
if is_sequence:
steps = len(generator)
@@ -2368,8 +2420,7 @@ class Model(Network):
max_queue_size=10,
workers=1,
use_multiprocessing=False,
- verbose=0,
- **kwargs):
+ verbose=0):
"""Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
@@ -2377,9 +2428,9 @@ class Model(Network):
Arguments:
generator: Generator yielding batches of input samples
- or an instance of Sequence (keras.utils.Sequence)
- object in order to avoid duplicate data
- when using multiprocessing.
+ or an instance of Sequence (keras.utils.Sequence)
+ object in order to avoid duplicate data
+ when using multiprocessing.
steps: Total number of steps (batches of samples)
to yield from `generator` before stopping.
Optional for `Sequence`: if unspecified, will use
@@ -2397,7 +2448,6 @@ class Model(Network):
as they can't be passed
easily to children processes.
verbose: verbosity mode, 0 or 1.
- **kwargs: support for legacy arguments.
Returns:
Numpy array(s) of predictions.
@@ -2406,17 +2456,6 @@ class Model(Network):
ValueError: In case the generator yields
data in an invalid format.
"""
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
-
self._make_predict_function()
steps_done = 0
@@ -2424,10 +2463,11 @@ class Model(Network):
all_outs = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
- logging.warn('Using a generator with `use_multiprocessing=True`'
- ' and multiple workers may duplicate your data.'
- ' Please consider using the`keras.utils.Sequence'
- ' class.')
+ logging.warning(
+ UserWarning('Using a generator with `use_multiprocessing=True`'
+ ' and multiple workers may duplicate your data.'
+ ' Please consider using the`keras.utils.Sequence'
+ ' class.'))
if steps is None:
if is_sequence:
steps = len(generator)
@@ -2498,6 +2538,6 @@ class Model(Network):
else:
return np.concatenate(all_outs[0])
if steps_done == 1:
- return [out for out in all_outs]
+ return [out[0] for out in all_outs]
else:
return [np.concatenate(out) for out in all_outs]
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 7650bfb6e8..5a033a04ad 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -28,6 +28,11 @@ from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.keras._impl.keras.engine.training import _weighted_masked_objective
from tensorflow.python.platform import test
+try:
+ import scipy.sparse as scipy_sparse # pylint: disable=g-import-not-at-top
+except ImportError:
+ scipy_sparse = None
+
class TrainingTest(test.TestCase):
@@ -169,7 +174,7 @@ class TrainingTest(test.TestCase):
with self.assertRaises(ValueError):
model.train_on_batch({'input_a': input_a_np},
[output_d_np, output_e_np])
- with self.assertRaises(TypeError):
+ with self.assertRaises(AttributeError):
model.fit(
[input_a_np, input_b_np], [output_d_np, output_e_np],
epochs=1,
@@ -177,7 +182,7 @@ class TrainingTest(test.TestCase):
verbose=0)
with self.assertRaises(ValueError):
model.train_on_batch([input_a_np], [output_d_np, output_e_np])
- with self.assertRaises(TypeError):
+ with self.assertRaises(AttributeError):
model.train_on_batch(1, [output_d_np, output_e_np])
with self.assertRaises(ValueError):
model.train_on_batch(input_a_np, [output_d_np, output_e_np])
@@ -312,6 +317,63 @@ class TrainingTest(test.TestCase):
model.compile(loss=None,
optimizer='rmsprop')
+ def test_training_on_sparse_data_with_dense_placeholders(self):
+ if scipy_sparse is None:
+ return
+
+ test_inputs = [
+ scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
+ test_outputs = [
+ scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
+ in1 = keras.layers.Input(shape=(3,))
+ in2 = keras.layers.Input(shape=(3,))
+ out1 = keras.layers.Dropout(0.5, name='dropout')(in1)
+ out2 = keras.layers.Dense(4, name='dense_1')(in2)
+ model = keras.Model([in1, in2], [out1, out2])
+ model.predict(test_inputs, batch_size=2)
+ model.compile('rmsprop', 'mse')
+ model.fit(test_inputs, test_outputs,
+ epochs=1, batch_size=2, validation_split=0.5)
+ model.evaluate(test_inputs, test_outputs, batch_size=2)
+
+ def test_that_trainable_disables_updates(self):
+ val_a = np.random.random((10, 4))
+ val_out = np.random.random((10, 4))
+
+ with self.test_session():
+ a = keras.layers.Input(shape=(4,))
+ layer = keras.layers.BatchNormalization(input_shape=(4,))
+ b = layer(a)
+ model = keras.Model(a, b)
+
+ model.trainable = False
+ assert not model.updates
+
+ model.compile('sgd', 'mse')
+ assert not model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
+ model.trainable = True
+ model.compile('sgd', 'mse')
+ assert model.updates
+
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ assert np.abs(np.sum(x1 - x2)) > 1e-5
+
+ layer.trainable = False
+ model.compile('sgd', 'mse')
+ assert not model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
class LossWeightingTest(test.TestCase):
@@ -869,25 +931,6 @@ class TestGeneratorMethods(test.TestCase):
use_multiprocessing=False,
workers=0)
- # Test legacy API
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
- max_q_size=10,
- workers=4,
- pickle_safe=True)
- model.predict_generator(custom_generator(),
- steps=5,
- max_q_size=10,
- workers=2,
- pickle_safe=True)
- model.evaluate_generator(custom_generator(),
- steps=5,
- max_q_size=10,
- workers=2,
- pickle_safe=True)
-
def test_generator_methods_with_sample_weights(self):
arr_data = np.random.random((50, 2))
arr_labels = np.random.random((50,))
@@ -960,7 +1003,7 @@ class TestGeneratorMethods(test.TestCase):
use_multiprocessing=False,
validation_data=custom_generator(),
validation_steps=10)
- with self.assertRaises(TypeError):
+ with self.assertRaises(AttributeError):
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
index e4b9afd38a..ffbf77c4b8 100644
--- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
@@ -14,18 +14,18 @@
# ==============================================================================
"""Layers that act as activation functions.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import activations
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
class LeakyReLU(Layer):
@@ -61,6 +61,7 @@ class LeakyReLU(Layer):
base_config = super(LeakyReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
@@ -114,9 +115,9 @@ class PReLU(Layer):
else:
self.shared_axes = list(shared_axes)
+ @shape_type_conversion
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- param_shape = input_shape[1:]
+ param_shape = list(input_shape[1:])
self.param_broadcast = [False] * len(param_shape)
if self.shared_axes is not None:
for i in self.shared_axes:
@@ -140,15 +141,13 @@ class PReLU(Layer):
def call(self, inputs, mask=None):
pos = K.relu(inputs)
if K.backend() == 'theano':
- neg = (K.pattern_broadcast(self.alpha, self.param_broadcast) *
- (inputs - K.abs(inputs)) * 0.5)
+ neg = (
+ K.pattern_broadcast(self.alpha, self.param_broadcast) *
+ (inputs - K.abs(inputs)) * 0.5)
else:
neg = -self.alpha * K.relu(-inputs)
return pos + neg
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {
'alpha_initializer': initializers.serialize(self.alpha_initializer),
@@ -159,6 +158,10 @@ class PReLU(Layer):
base_config = super(PReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
class ELU(Layer):
"""Exponential Linear Unit.
@@ -188,14 +191,15 @@ class ELU(Layer):
def call(self, inputs):
return K.elu(inputs, self.alpha)
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {'alpha': float(self.alpha)}
base_config = super(ELU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
class ThresholdedReLU(Layer):
"""Thresholded Rectified Linear Unit.
@@ -223,12 +227,46 @@ class ThresholdedReLU(Layer):
self.theta = K.cast_to_floatx(theta)
def call(self, inputs, mask=None):
- return inputs * K.cast(inputs > self.theta, K.floatx())
+ return inputs * K.cast(K.greater(inputs, self.theta), K.floatx())
+
+ def get_config(self):
+ config = {'theta': float(self.theta)}
+ base_config = super(ThresholdedReLU, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
+
+class Softmax(Layer):
+ """Softmax activation function.
+
+ Input shape:
+ Arbitrary. Use the keyword argument `input_shape`
+ (tuple of integers, does not include the samples axis)
+ when using this layer as the first layer in a model.
+
+ Output shape:
+ Same shape as the input.
+
+ Arguments:
+ axis: Integer, axis along which the softmax normalization is applied.
+ """
+
+ def __init__(self, axis=-1, **kwargs):
+ super(Softmax, self).__init__(**kwargs)
+ self.supports_masking = True
+ self.axis = axis
+
+ def call(self, inputs):
+ return activations.softmax(inputs, axis=self.axis)
+
def get_config(self):
- config = {'theta': float(self.theta)}
- base_config = super(ThresholdedReLU, self).get_config()
+ config = {'axis': self.axis}
+ base_config = super(Softmax, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
index 91efab30ed..343b7949ac 100644
--- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
@@ -56,6 +56,12 @@ class AdvancedActivationsTest(test.TestCase):
kwargs={'theta': 0.5},
input_shape=(2, 3, 4))
+ def test_softmax(self):
+ with self.test_session():
+ testing_utils.layer_test(keras.layers.Softmax,
+ kwargs={'axis': 1},
+ input_shape=(2, 3, 4))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
index 22496e8a76..2ee0732775 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
@@ -563,7 +563,7 @@ class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer):
return dict(list(base_config.items()) + list(config.items()))
-class Conv3DTranspose(tf_convolutional_layers.Conv3D, Layer):
+class Conv3DTranspose(tf_convolutional_layers.Conv3DTranspose, Layer):
"""Transposed convolution layer (sometimes called Deconvolution).
The need for transposed convolutions generally arises
@@ -711,6 +711,144 @@ class Conv3DTranspose(tf_convolutional_layers.Conv3D, Layer):
return dict(list(base_config.items()) + list(config.items()))
+class SeparableConv1D(tf_convolutional_layers.SeparableConv1D, Layer):
+ """Depthwise separable 1D convolution.
+
+ This layer performs a depthwise convolution that acts separately on
+ channels, followed by a pointwise convolution that mixes channels.
+ If `use_bias` is True and a bias initializer is provided,
+ it adds a bias vector to the output.
+ It then optionally applies an activation function to produce the final output.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: A single integer specifying the spatial
+ dimensions of the filters.
+ strides: A single integer specifying the strides
+ of the convolution.
+ Specifying any `stride` value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, length, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, length)`.
+ dilation_rate: A single integer, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ depth_multiplier: The number of depthwise convolution output channels for
+ each input channel. The total number of depthwise convolution output
+ channels will be equal to `num_filters_in * depth_multiplier`.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ depthwise_initializer: An initializer for the depthwise convolution kernel.
+ pointwise_initializer: An initializer for the pointwise convolution kernel.
+ bias_initializer: An initializer for the bias vector. If None, the default
+ initializer will be used.
+ depthwise_regularizer: Optional regularizer for the depthwise
+ convolution kernel.
+ pointwise_regularizer: Optional regularizer for the pointwise
+ convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Optional regularizer function for the output.
+ depthwise_constraint: Optional projection function to be applied to the
+ depthwise kernel after being updated by an `Optimizer` (e.g. used for
+ norm constraints or value constraints for layer weights). The function
+ must take as input the unprojected variable and must return the
+ projected variable (which must have the same shape). Constraints are
+ not safe to use when doing asynchronous distributed training.
+ pointwise_constraint: Optional projection function to be applied to the
+ pointwise kernel after being updated by an `Optimizer`.
+ bias_constraint: Optional projection function to be applied to the
+ bias after being updated by an `Optimizer`.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ """
+
+ def __init__(self,
+ filters,
+ kernel_size,
+ strides=1,
+ padding='valid',
+ data_format=None,
+ dilation_rate=1,
+ depth_multiplier=1,
+ activation=None,
+ use_bias=True,
+ depthwise_initializer='glorot_uniform',
+ pointwise_initializer='glorot_uniform',
+ bias_initializer='zeros',
+ depthwise_regularizer=None,
+ pointwise_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ depthwise_constraint=None,
+ pointwise_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ if data_format is None:
+ data_format = K.image_data_format()
+ super(SeparableConv1D, self).__init__(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activations.get(activation),
+ use_bias=use_bias,
+ depthwise_initializer=initializers.get(depthwise_initializer),
+ pointwise_initializer=initializers.get(pointwise_initializer),
+ bias_initializer=initializers.get(bias_initializer),
+ depthwise_regularizer=regularizers.get(depthwise_regularizer),
+ pointwise_regularizer=regularizers.get(pointwise_regularizer),
+ bias_regularizer=regularizers.get(bias_regularizer),
+ activity_regularizer=regularizers.get(activity_regularizer),
+ depthwise_constraint=constraints.get(depthwise_constraint),
+ pointwise_constraint=constraints.get(pointwise_constraint),
+ bias_constraint=constraints.get(bias_constraint),
+ **kwargs)
+
+ def get_config(self):
+ config = {
+ 'filters': self.filters,
+ 'kernel_size': self.kernel_size,
+ 'strides': self.strides,
+ 'padding': self.padding,
+ 'data_format': self.data_format,
+ 'dilation_rate': self.dilation_rate,
+ 'activation': activations.serialize(self.activation),
+ 'use_bias': self.use_bias,
+ 'depthwise_initializer':
+ initializers.serialize(self.depthwise_initializer),
+ 'pointwise_initializer':
+ initializers.serialize(self.pointwise_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'depthwise_regularizer':
+ regularizers.serialize(self.depthwise_regularizer),
+ 'pointwise_regularizer':
+ regularizers.serialize(self.pointwise_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'activity_regularizer':
+ regularizers.serialize(self.activity_regularizer),
+ 'depthwise_constraint':
+ constraints.serialize(self.depthwise_constraint),
+ 'pointwise_constraint':
+ constraints.serialize(self.pointwise_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint)
+ }
+ base_config = super(SeparableConv1D, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer):
"""Depthwise separable 2D convolution.
@@ -1663,6 +1801,7 @@ class Cropping3D(Layer):
Convolution1D = Conv1D
Convolution2D = Conv2D
Convolution3D = Conv3D
+SeparableConvolution1D = SeparableConv1D
SeparableConvolution2D = SeparableConv2D
Convolution2DTranspose = Conv2DTranspose
Convolution3DTranspose = Conv3DTranspose
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
index 4f0e9fc691..565db19e41 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import activations
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent
from tensorflow.python.keras._impl.keras.utils import conv_utils
@@ -127,10 +127,10 @@ class ConvRecurrent2D(Recurrent):
self.input_spec = [InputSpec(ndim=5)]
self.state_spec = None
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
rows = input_shape[3]
cols = input_shape[4]
@@ -151,30 +151,28 @@ class ConvRecurrent2D(Recurrent):
dilation=self.dilation_rate[1])
if self.return_sequences:
if self.data_format == 'channels_first':
- output_shape = [input_shape[0], input_shape[1],
- self.filters, rows, cols]
+ output_shape = (input_shape[0], input_shape[1], self.filters, rows,
+ cols)
elif self.data_format == 'channels_last':
- output_shape = [input_shape[0], input_shape[1],
- rows, cols, self.filters]
+ output_shape = (input_shape[0], input_shape[1], rows, cols,
+ self.filters)
else:
if self.data_format == 'channels_first':
- output_shape = [input_shape[0], self.filters, rows, cols]
+ output_shape = (input_shape[0], self.filters, rows, cols)
elif self.data_format == 'channels_last':
- output_shape = [input_shape[0], rows, cols, self.filters]
+ output_shape = (input_shape[0], rows, cols, self.filters)
if self.return_state:
if self.data_format == 'channels_first':
- output_shapes = [output_shape] + [(input_shape[0],
- self.filters,
- rows,
- cols) for _ in range(2)]
+ output_shape = [output_shape] + [
+ (input_shape[0], self.filters, rows, cols) for _ in range(2)
+ ]
elif self.data_format == 'channels_last':
- output_shapes = [output_shape] + [(input_shape[0],
- rows,
- cols,
- self.filters) for _ in range(2)]
- return [tensor_shape.TensorShape(shape) for shape in output_shapes]
- return tensor_shape.TensorShape(output_shape)
+ output_shape = [output_shape] + [
+ (input_shape[0], rows, cols, self.filters) for _ in range(2)
+ ]
+
+ return output_shape
def get_config(self):
config = {
@@ -294,11 +292,6 @@ class ConvLSTM2D(ConvRecurrent2D):
Raises:
ValueError: in case of invalid constructor arguments.
- References:
- - [Convolutional LSTM Network: A Machine Learning Approach for
- Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1)
- The current implementation does not include the feedback loop on the
- cells output
"""
def __init__(self,
@@ -338,7 +331,6 @@ class ConvLSTM2D(ConvRecurrent2D):
return_sequences=return_sequences,
go_backwards=go_backwards,
stateful=stateful,
- activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
@@ -352,6 +344,7 @@ class ConvLSTM2D(ConvRecurrent2D):
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
@@ -361,13 +354,12 @@ class ConvLSTM2D(ConvRecurrent2D):
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.state_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
+ @shape_type_conversion
def build(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
batch_size = input_shape[0] if self.stateful else None
self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:])
-
if self.stateful:
self.reset_states()
else:
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
index be7da6f2b4..39c9d4f0fb 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
@@ -311,6 +311,72 @@ class Conv3DTransposeTest(test.TestCase):
self.assertEqual(layer.bias.constraint, b_constraint)
+class SeparableConv1DTest(test.TestCase):
+
+ def test_separable_conv_1d(self):
+ num_samples = 2
+ filters = 6
+ stack_size = 3
+ length = 7
+ strides = 1
+
+ for padding in ['valid', 'same']:
+ for multiplier in [1, 2]:
+ if padding == 'same' and strides != 1:
+ continue
+
+ with self.test_session(use_gpu=True):
+ testing_utils.layer_test(
+ keras.layers.SeparableConv1D,
+ kwargs={
+ 'filters': filters,
+ 'kernel_size': 3,
+ 'padding': padding,
+ 'strides': strides,
+ 'depth_multiplier': multiplier
+ },
+ input_shape=(num_samples, length, stack_size))
+
+ def test_separable_conv1d_regularizers(self):
+ kwargs = {
+ 'filters': 3,
+ 'kernel_size': 3,
+ 'padding': 'valid',
+ 'depthwise_regularizer': 'l2',
+ 'pointwise_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'activity_regularizer': 'l2',
+ 'strides': 1
+ }
+ with self.test_session(use_gpu=True):
+ layer = keras.layers.SeparableConv1D(**kwargs)
+ layer.build((None, 5, 2))
+ self.assertEqual(len(layer.losses), 3)
+ layer(keras.backend.variable(np.ones((1, 5, 2))))
+ self.assertEqual(len(layer.losses), 4)
+
+ def test_separable_conv1d_constraints(self):
+ d_constraint = lambda x: x
+ p_constraint = lambda x: x
+ b_constraint = lambda x: x
+
+ kwargs = {
+ 'filters': 3,
+ 'kernel_size': 3,
+ 'padding': 'valid',
+ 'pointwise_constraint': p_constraint,
+ 'depthwise_constraint': d_constraint,
+ 'bias_constraint': b_constraint,
+ 'strides': 1
+ }
+ with self.test_session(use_gpu=True):
+ layer = keras.layers.SeparableConv1D(**kwargs)
+ layer.build((None, 5, 2))
+ self.assertEqual(layer.depthwise_kernel.constraint, d_constraint)
+ self.assertEqual(layer.pointwise_kernel.constraint, p_constraint)
+ self.assertEqual(layer.bias.constraint, b_constraint)
+
+
class SeparableConv2DTest(test.TestCase):
def test_separable_conv_2d(self):
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
index 51c520be38..f8e31068f8 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
class Embedding(Layer):
@@ -58,13 +58,13 @@ class Embedding(Layer):
output_dim: int >= 0. Dimension of the dense embedding.
embeddings_initializer: Initializer for the `embeddings` matrix.
embeddings_regularizer: Regularizer function applied to
- the `embeddings` matrix.
+ the `embeddings` matrix.
embeddings_constraint: Constraint function applied to
- the `embeddings` matrix.
+ the `embeddings` matrix.
mask_zero: Whether or not the input value 0 is a special "padding"
value that should be masked out.
- This is useful when using recurrent layers,
- which may take variable length inputs.
+ This is useful when using recurrent layers
+ which may take variable length input.
If this is `True` then all subsequent layers
in the model need to support masking or an exception will be raised.
If mask_zero is set to True, as a consequence, index 0 cannot be
@@ -81,9 +81,6 @@ class Embedding(Layer):
Output shape:
3D tensor with shape: `(batch_size, sequence_length, output_dim)`.
- References:
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural
- Networks](http://arxiv.org/abs/1512.05287)
"""
def __init__(self,
@@ -101,19 +98,19 @@ class Embedding(Layer):
kwargs['input_shape'] = (input_length,)
else:
kwargs['input_shape'] = (None,)
- super(Embedding, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(Embedding, self).__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.embeddings_initializer = initializers.get(embeddings_initializer)
self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
self.embeddings_constraint = constraints.get(embeddings_constraint)
self.mask_zero = mask_zero
self.input_length = input_length
+ @shape_type_conversion
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
self.embeddings = self.add_weight(
shape=(self.input_dim, self.output_dim),
initializer=self.embeddings_initializer,
@@ -129,10 +126,10 @@ class Embedding(Layer):
else:
return K.not_equal(inputs, 0)
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.input_length is None:
- return tensor_shape.TensorShape(input_shape + [self.output_dim])
+ return input_shape + (self.output_dim,)
else:
# input_length can be tuple if input is 3D or higher
if isinstance(self.input_length, (list, tuple)):
@@ -149,8 +146,7 @@ class Embedding(Layer):
(str(self.input_length), str(input_shape)))
elif s1 is None:
in_lens[i] = s2
- return tensor_shape.TensorShape(
- (input_shape[0],) + tuple(in_lens) + (self.output_dim,))
+ return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)
def call(self, inputs):
if K.dtype(inputs) != 'int32':
diff --git a/tensorflow/python/keras/_impl/keras/layers/local.py b/tensorflow/python/keras/_impl/keras/layers/local.py
index 0a31b87fb5..b844b071e0 100644
--- a/tensorflow/python/keras/_impl/keras/layers/local.py
+++ b/tensorflow/python/keras/_impl/keras/layers/local.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import activations
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
@@ -26,6 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils import conv_utils
@@ -98,8 +98,7 @@ class LocallyConnected1D(Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
- super(LocallyConnected1D, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(LocallyConnected1D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
@@ -114,12 +113,13 @@ class LocallyConnected1D(Layer):
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=3)
+ @shape_type_conversion
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
input_dim = input_shape[2]
if input_dim is None:
raise ValueError('Axis 2 of input should be fully-defined. '
@@ -146,15 +146,14 @@ class LocallyConnected1D(Layer):
self.input_spec = InputSpec(ndim=3, axes={2: input_dim})
self.built = True
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
length = conv_utils.conv_output_length(input_shape[1], self.kernel_size[0],
self.padding, self.strides[0])
- return tensor_shape.TensorShape([input_shape[0], length, self.filters])
+ return (input_shape[0], length, self.filters)
def call(self, inputs):
output = K.local_conv1d(inputs, self.kernel, self.kernel_size, self.strides)
-
if self.use_bias:
output = K.bias_add(output, self.bias)
if self.activation is not None:
@@ -163,20 +162,32 @@ class LocallyConnected1D(Layer):
def get_config(self):
config = {
- 'filters': self.filters,
- 'kernel_size': self.kernel_size,
- 'strides': self.strides,
- 'padding': self.padding,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'filters':
+ self.filters,
+ 'kernel_size':
+ self.kernel_size,
+ 'strides':
+ self.strides,
+ 'padding':
+ self.padding,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint)
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint)
}
base_config = super(LocallyConnected1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -273,8 +284,7 @@ class LocallyConnected2D(Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
- super(LocallyConnected2D, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(LocallyConnected2D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
@@ -289,12 +299,13 @@ class LocallyConnected2D(Layer):
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=4)
+ @shape_type_conversion
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_last':
input_row, input_col = input_shape[1:-1]
input_filter = input_shape[3]
@@ -306,7 +317,6 @@ class LocallyConnected2D(Layer):
' a LocallyConnected2D layer '
'should be fully-defined, but layer received '
'the inputs shape ' + str(input_shape))
-
output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0],
self.padding, self.strides[0])
output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1],
@@ -337,33 +347,30 @@ class LocallyConnected2D(Layer):
self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
self.built = True
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
rows = input_shape[2]
cols = input_shape[3]
elif self.data_format == 'channels_last':
rows = input_shape[1]
cols = input_shape[2]
+
rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
self.padding, self.strides[0])
cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
self.padding, self.strides[1])
if self.data_format == 'channels_first':
- return tensor_shape.TensorShape(
- [input_shape[0], self.filters, rows, cols])
+ return (input_shape[0], self.filters, rows, cols)
elif self.data_format == 'channels_last':
- return tensor_shape.TensorShape(
- [input_shape[0], rows, cols, self.filters])
+ return (input_shape[0], rows, cols, self.filters)
def call(self, inputs):
- output = K.local_conv2d(inputs,
- self.kernel,
- self.kernel_size,
- self.strides,
+ output = K.local_conv2d(inputs, self.kernel, self.kernel_size, self.strides,
(self.output_row, self.output_col),
self.data_format)
+
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -372,21 +379,34 @@ class LocallyConnected2D(Layer):
def get_config(self):
config = {
- 'filters': self.filters,
- 'kernel_size': self.kernel_size,
- 'strides': self.strides,
- 'padding': self.padding,
- 'data_format': self.data_format,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'filters':
+ self.filters,
+ 'kernel_size':
+ self.kernel_size,
+ 'strides':
+ self.strides,
+ 'padding':
+ self.padding,
+ 'data_format':
+ self.data_format,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint)
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint)
}
base_config = super(LocallyConnected2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py
index 76eb03cf27..38b0b30297 100644
--- a/tensorflow/python/keras/_impl/keras/layers/merge.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge.py
@@ -14,15 +14,15 @@
# ==============================================================================
# pylint: disable=not-callable
# pylint: disable=redefined-builtin
-"""Layers can merge several input tensors into a single output tensor.
+"""Layers that can merge several inputs into one.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine.topology import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
class _Merge(Layer):
@@ -73,12 +73,13 @@ class _Merge(Layer):
output_shape.append(i)
else:
if i != j:
- raise ValueError('Operands could not be broadcast '
- 'together with shapes ' + str(shape1) + ' ' +
- str(shape2))
+ raise ValueError(
+ 'Operands could not be broadcast '
+ 'together with shapes ' + str(shape1) + ' ' + str(shape2))
output_shape.append(i)
return tuple(output_shape)
+ @shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list):
@@ -87,14 +88,13 @@ class _Merge(Layer):
raise ValueError('A merge layer should be called '
'on a list of at least 2 inputs. '
'Got ' + str(len(input_shape)) + ' inputs.')
- input_shape = [tensor_shape.TensorShape(s).as_list() for s in input_shape]
batch_sizes = [s[0] for s in input_shape if s is not None]
batch_sizes = set(batch_sizes)
batch_sizes -= set([None])
if len(batch_sizes) > 1:
- raise ValueError('Can not merge tensors with different '
- 'batch sizes. Got tensors with shapes : ' +
- str(input_shape))
+ raise ValueError(
+ 'Can not merge tensors with different '
+ 'batch sizes. Got tensors with shapes : ' + str(input_shape))
if input_shape[0] is None:
output_shape = None
else:
@@ -111,9 +111,10 @@ class _Merge(Layer):
self._reshape_required = False
else:
self._reshape_required = True
- self.built = True
def call(self, inputs):
+ if not isinstance(inputs, list):
+ raise ValueError('A merge layer should be called ' 'on a list of inputs.')
if self._reshape_required:
reshaped_inputs = []
input_ndims = list(map(K.ndim, inputs))
@@ -172,6 +173,7 @@ class _Merge(Layer):
else:
return self._merge_function(inputs)
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if input_shape[0] is None:
output_shape = None
@@ -214,6 +216,22 @@ class Add(_Merge):
It takes as input a list of tensors,
all of the same shape, and returns
a single tensor (also of the same shape).
+
+ Examples:
+
+ ```python
+ import keras
+
+ input1 = keras.layers.Input(shape=(16,))
+ x1 = keras.layers.Dense(8, activation='relu')(input1)
+ input2 = keras.layers.Input(shape=(32,))
+ x2 = keras.layers.Dense(8, activation='relu')(input2)
+ added = keras.layers.Add()([x1, x2]) # equivalent to added =
+ keras.layers.add([x1, x2])
+
+ out = keras.layers.Dense(4)(added)
+ model = keras.models.Model(inputs=[input1, input2], outputs=out)
+ ```
"""
def _merge_function(self, inputs):
@@ -247,10 +265,17 @@ class Subtract(_Merge):
```
"""
+ @shape_type_conversion
+ def build(self, input_shape):
+ super(Subtract, self).build(input_shape)
+ if len(input_shape) != 2:
+ raise ValueError('A `Subtract` layer should be called '
+ 'on exactly 2 inputs')
+
def _merge_function(self, inputs):
if len(inputs) != 2:
- raise ValueError('`Subtract` layer should be called '
- 'on exactly 2 inputs. Received: %s' % inputs)
+ raise ValueError('A `Subtract` layer should be called '
+ 'on exactly 2 inputs')
return inputs[0] - inputs[1]
@@ -330,47 +355,43 @@ class Concatenate(_Merge):
super(Concatenate, self).__init__(**kwargs)
self.axis = axis
self.supports_masking = True
+ self._reshape_required = False
+ @shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
- if not (isinstance(input_shape, list) and len(input_shape) > 1):
- raise ValueError('`Concatenate` layer should be called '
- 'on a list containing at least two inputs')
+ if not isinstance(input_shape, list) or len(input_shape) < 2:
+ raise ValueError('A `Concatenate` layer should be called '
+ 'on a list of at least 2 inputs')
if all([shape is None for shape in input_shape]):
return
- reduced_inputs_shapes = [
- tensor_shape.TensorShape(shape).as_list() for shape in input_shape
- ]
+ reduced_inputs_shapes = [list(shape) for shape in input_shape]
shape_set = set()
for i in range(len(reduced_inputs_shapes)):
del reduced_inputs_shapes[i][self.axis]
shape_set.add(tuple(reduced_inputs_shapes[i]))
if len(shape_set) > 1:
- raise ValueError('`Concatenate` layer requires '
+ raise ValueError('A `Concatenate` layer requires '
'inputs with matching shapes '
'except for the concat axis. '
'Got inputs shapes: %s' % (input_shape))
- self.built = True
- def call(self, inputs):
- if not isinstance(inputs, list):
- raise ValueError('A `Concatenate` layer should be called '
- 'on a list of inputs.')
+ def _merge_function(self, inputs):
return K.concatenate(inputs, axis=self.axis)
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if not isinstance(input_shape, list):
raise ValueError('A `Concatenate` layer should be called '
'on a list of inputs.')
input_shapes = input_shape
- output_shape = tensor_shape.TensorShape(input_shapes[0]).as_list()
+ output_shape = list(input_shapes[0])
for shape in input_shapes[1:]:
- shape = tensor_shape.TensorShape(shape).as_list()
if output_shape[self.axis] is None or shape[self.axis] is None:
output_shape[self.axis] = None
break
output_shape[self.axis] += shape[self.axis]
- return tensor_shape.TensorShape(output_shape)
+ return tuple(output_shape)
def compute_mask(self, inputs, mask=None):
if mask is None:
@@ -390,7 +411,7 @@ class Concatenate(_Merge):
masks = []
for input_i, mask_i in zip(inputs, mask):
if mask_i is None:
- # Input is unmasked. Append all 1s to masks
+ # Input is unmasked. Append all 1s to masks,
masks.append(K.ones_like(input_i, dtype='bool'))
elif K.ndim(mask_i) < K.ndim(input_i):
# Mask is smaller than the input, expand it
@@ -441,14 +462,16 @@ class Dot(_Merge):
self.axes = axes
self.normalize = normalize
self.supports_masking = True
+ self._reshape_required = False
+ @shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list) or len(input_shape) != 2:
raise ValueError('A `Dot` layer should be called '
'on a list of 2 inputs.')
- shape1 = tensor_shape.TensorShape(input_shape[0]).as_list()
- shape2 = tensor_shape.TensorShape(input_shape[1]).as_list()
+ shape1 = input_shape[0]
+ shape2 = input_shape[1]
if shape1 is None or shape2 is None:
return
if isinstance(self.axes, int):
@@ -462,9 +485,10 @@ class Dot(_Merge):
raise ValueError('Dimension incompatibility '
'%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
'Layer shapes: %s, %s' % (shape1, shape2))
- self.built = True
- def call(self, inputs):
+ def _merge_function(self, inputs):
+ if len(inputs) != 2:
+ raise ValueError('A `Dot` layer should be called ' 'on exactly 2 inputs')
x1 = inputs[0]
x2 = inputs[1]
if isinstance(self.axes, int):
@@ -485,12 +509,13 @@ class Dot(_Merge):
output = K.batch_dot(x1, x2, axes)
return output
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if not isinstance(input_shape, list) or len(input_shape) != 2:
raise ValueError('A `Dot` layer should be called '
'on a list of 2 inputs.')
- shape1 = tensor_shape.TensorShape(input_shape[0]).as_list()
- shape2 = tensor_shape.TensorShape(input_shape[1]).as_list()
+ shape1 = list(input_shape[0])
+ shape2 = list(input_shape[1])
if isinstance(self.axes, int):
if self.axes < 0:
axes = [self.axes % len(shape1), self.axes % len(shape2)]
@@ -504,7 +529,7 @@ class Dot(_Merge):
output_shape = shape1 + shape2
if len(output_shape) == 1:
output_shape += [1]
- return tensor_shape.TensorShape(output_shape)
+ return tuple(output_shape)
def compute_mask(self, inputs, mask=None):
return None
@@ -527,6 +552,21 @@ def add(inputs, **kwargs):
Returns:
A tensor, the sum of the inputs.
+
+ Examples:
+
+ ```python
+ import keras
+
+ input1 = keras.layers.Input(shape=(16,))
+ x1 = keras.layers.Dense(8, activation='relu')(input1)
+ input2 = keras.layers.Input(shape=(32,))
+ x2 = keras.layers.Dense(8, activation='relu')(input2)
+ added = keras.layers.add([x1, x2])
+
+ out = keras.layers.Dense(4)(added)
+ model = keras.models.Model(inputs=[input1, input2], outputs=out)
+ ```
"""
return Add(**kwargs)(inputs)
diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py
index 459f13145f..04fffcc384 100644
--- a/tensorflow/python/keras/_impl/keras/layers/noise.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Layers for regularization models via the addition of noise.
+"""Layers that operate regularization via the addition of noise.
"""
from __future__ import absolute_import
from __future__ import division
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
class GaussianNoise(Layer):
@@ -59,14 +60,15 @@ class GaussianNoise(Layer):
return K.in_train_phase(noised, inputs, training=training)
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {'stddev': self.stddev}
base_config = super(GaussianNoise, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
class GaussianDropout(Layer):
"""Apply multiplicative 1-centered Gaussian noise.
@@ -86,10 +88,6 @@ class GaussianDropout(Layer):
Output shape:
Same shape as input.
- References:
- - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
- Srivastava, Hinton, et al.
- 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
"""
def __init__(self, rate, **kwargs):
@@ -108,14 +106,15 @@ class GaussianDropout(Layer):
return K.in_train_phase(noised, inputs, training=training)
return inputs
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {'rate': self.rate}
base_config = super(GaussianDropout, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
class AlphaDropout(Layer):
"""Applies Alpha Dropout to the input.
@@ -140,8 +139,6 @@ class AlphaDropout(Layer):
Output shape:
Same shape as input.
- References:
- - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
"""
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
@@ -157,26 +154,34 @@ class AlphaDropout(Layer):
def call(self, inputs, training=None):
if 0. < self.rate < 1.:
noise_shape = self._get_noise_shape(inputs)
- alpha = 1.6732632423543772848170429916717
- scale = 1.0507009873554804934193349852946
- def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed):
+ def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed): # pylint: disable=missing-docstring
+ alpha = 1.6732632423543772848170429916717
+ scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale
- kept_idx = K.greater_equal(K.random_uniform(noise_shape, seed=seed),
- rate)
+
+ kept_idx = K.greater_equal(
+ K.random_uniform(noise_shape, seed=seed), rate)
kept_idx = K.cast(kept_idx, K.floatx())
- a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5
+
+ # Get affine transformation params
+ a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
b = -a * alpha_p * rate
+
+ # Apply mask
x = inputs * kept_idx + alpha_p * (1 - kept_idx)
+
+ # Do affine transformation
return a * x + b
return K.in_train_phase(dropped_inputs, inputs, training=training)
return inputs
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {'rate': self.rate}
base_config = super(AlphaDropout, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 6e38cf2f41..1b0f6cb6cf 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
-"""Recurrent layers.
+"""Recurrent layers and their base classes.
"""
from __future__ import absolute_import
from __future__ import division
@@ -29,6 +29,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.platform import tf_logging as logging
@@ -109,6 +110,7 @@ class StackedRNNCells(Layer):
states += cell_states
return inputs, states
+ @shape_type_conversion
def build(self, input_shape):
for cell in self.cells:
if isinstance(cell, Layer):
@@ -117,7 +119,7 @@ class StackedRNNCells(Layer):
output_dim = cell.state_size[0]
else:
output_dim = cell.state_size
- input_shape = (input_shape[0], input_shape[1], output_dim)
+ input_shape = (input_shape[0], output_dim)
self.built = True
def get_config(self):
@@ -262,8 +264,7 @@ class RNN(Layer):
(e.g. via the `input_shape` argument)
Input shape:
- 3D tensor with shape `(batch_size, timesteps, input_dim)`,
- (Optional) 2D tensors with shape `(batch_size, output_dim)`.
+ 3D tensor with shape `(batch_size, timesteps, input_dim)`.
Output shape:
- if `return_state`: a list of tensors. The first tensor is
@@ -370,7 +371,6 @@ class RNN(Layer):
go_backwards=False,
stateful=False,
unroll=False,
- activity_regularizer=None,
**kwargs):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
@@ -382,8 +382,7 @@ class RNN(Layer):
'an attribute `state_size` '
'(tuple of integers, '
'one integer per RNN state).')
- super(RNN, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(RNN, self).__init__(**kwargs)
self.cell = cell
self.return_sequences = return_sequences
self.return_state = return_state
@@ -412,15 +411,16 @@ class RNN(Layer):
def states(self, states):
self._states = states
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if hasattr(self.cell.state_size, '__len__'):
- output_dim = self.cell.state_size[0]
+ state_size = self.cell.state_size
else:
- output_dim = self.cell.state_size
+ state_size = [self.cell.state_size]
+ output_dim = state_size[0]
if self.return_sequences:
output_shape = (input_shape[0], input_shape[1], output_dim)
@@ -428,11 +428,10 @@ class RNN(Layer):
output_shape = (input_shape[0], output_dim)
if self.return_state:
- state_shape = [(input_shape[0], output_dim) for _ in self.states]
- output_shape = [output_shape] + state_shape
+ state_shape = [(input_shape[0], dim) for dim in state_size]
+ return [output_shape] + state_shape
else:
- output_shape = output_shape
- return tensor_shape.TensorShape(output_shape)
+ return output_shape
def compute_mask(self, inputs, mask):
if isinstance(mask, list):
@@ -444,6 +443,7 @@ class RNN(Layer):
else:
return output_mask
+ @shape_type_conversion
def build(self, input_shape):
# Note input_shape will be list of shapes of initial states and
# constants if these are passed in __call__.
@@ -454,7 +454,6 @@ class RNN(Layer):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
batch_size = input_shape[0] if self.stateful else None
input_dim = input_shape[-1]
@@ -478,9 +477,9 @@ class RNN(Layer):
# initial_state was passed in call, check compatibility
if [spec.shape[-1] for spec in self.state_spec] != state_size:
raise ValueError(
- 'An initial_state was passed that is not compatible with '
+ 'An `initial_state` was passed that is not compatible with '
'`cell.state_size`. Received `state_spec`={}; '
- 'However `cell.state_size` is '
+ 'however `cell.state_size` is '
'{}'.format(self.state_spec, self.cell.state_size))
else:
self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
@@ -610,7 +609,8 @@ class RNN(Layer):
constants=constants,
go_backwards=self.go_backwards,
mask=mask,
- unroll=self.unroll)
+ unroll=self.unroll,
+ input_length=timesteps)
if self.stateful:
updates = []
for i in range(len(states)):
@@ -625,6 +625,8 @@ class RNN(Layer):
# Properly set learning phase
if getattr(last_output, '_uses_learning_phase', False):
output._uses_learning_phase = True
+ for state in states:
+ state._uses_learning_phase = True
if self.return_state:
if not isinstance(states, (list, tuple)):
@@ -636,7 +638,7 @@ class RNN(Layer):
return output
def _standardize_args(self, inputs, initial_state, constants):
- """Standardize `__call__` arguments to a single list of tensor inputs.
+ """Standardize `__call__` to a single list of tensor inputs.
When running a model loaded from file, the input tensors
`initial_state` and `constants` can be passed to `RNN.__call__` as part
@@ -688,7 +690,7 @@ class RNN(Layer):
'a `batch_input_shape` '
'argument to your first layer.\n'
'- If using the functional API, specify '
- 'the time dimension by passing a '
+ 'the batch size by passing a '
'`batch_shape` argument to your Input layer.')
# initialize state if None
if self.states[0] is None:
@@ -788,36 +790,26 @@ class SimpleRNNCell(Layer):
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+ activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -865,6 +857,7 @@ class SimpleRNNCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
+ @shape_type_conversion
def build(self, input_shape):
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
@@ -889,33 +882,21 @@ class SimpleRNNCell(Layer):
self.bias = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._dropout_mask = K.in_train_phase(
- dropped_inputs, ones, training=training)
- else:
- self._dropout_mask = None
-
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
- if 0 < self.recurrent_dropout < 1:
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, self.units))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._recurrent_dropout_mask = K.in_train_phase(
- dropped_inputs, ones, training=training)
- else:
- self._recurrent_dropout_mask = None
-
def call(self, inputs, states, training=None):
prev_output = states[0]
+ if 0 < self.dropout < 1 and self._dropout_mask is None:
+ self._dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs,
+ K.shape(inputs)[-1]),
+ self.dropout,
+ training=training)
+ if (0 < self.recurrent_dropout < 1 and
+ self._recurrent_dropout_mask is None):
+ self._recurrent_dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs, self.units),
+ self.recurrent_dropout,
+ training=training)
+
dp_mask = self._dropout_mask
rec_dp_mask = self._recurrent_dropout_mask
@@ -938,45 +919,68 @@ class SimpleRNNCell(Layer):
output._uses_learning_phase = True
return output, [output]
+ def get_config(self):
+ config = {
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'recurrent_initializer':
+ initializers.serialize(self.recurrent_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'recurrent_regularizer':
+ regularizers.serialize(self.recurrent_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'recurrent_constraint':
+ constraints.serialize(self.recurrent_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout
+ }
+ base_config = super(SimpleRNNCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
class SimpleRNN(RNN):
"""Fully-connected RNN where the output is to be fed back to input.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+ activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
+ the output of the layer (its "activation")..
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1050,12 +1054,12 @@ class SimpleRNN(RNN):
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
- activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ self.cell._dropout_mask = None
+ self.cell._recurrent_dropout_mask = None
return super(SimpleRNN, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@@ -1117,25 +1121,36 @@ class SimpleRNN(RNN):
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout
}
base_config = super(SimpleRNN, self).get_config()
del base_config['cell']
@@ -1153,39 +1168,28 @@ class GRUCell(Layer):
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+ activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1243,6 +1247,7 @@ class GRUCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
+ @shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
@@ -1286,38 +1291,24 @@ class GRUCell(Layer):
self.bias_h = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._dropout_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(3)
- ]
- else:
- self._dropout_mask = None
-
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
- if 0 < self.recurrent_dropout < 1:
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, self.units))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._recurrent_dropout_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(3)
- ]
- else:
- self._recurrent_dropout_mask = None
-
def call(self, inputs, states, training=None):
h_tm1 = states[0] # previous memory
+ if 0 < self.dropout < 1 and self._dropout_mask is None:
+ self._dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs,
+ K.shape(inputs)[-1]),
+ self.dropout,
+ training=training,
+ count=3)
+ if (0 < self.recurrent_dropout < 1 and
+ self._recurrent_dropout_mask is None):
+ self._recurrent_dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs, self.units),
+ self.recurrent_dropout,
+ training=training,
+ count=3)
+
# dropout matrices for input units
dp_mask = self._dropout_mask
# dropout matrices for recurrent units
@@ -1381,51 +1372,76 @@ class GRUCell(Layer):
h._uses_learning_phase = True
return h, [h]
+ def get_config(self):
+ config = {
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'recurrent_activation':
+ activations.serialize(self.recurrent_activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'recurrent_initializer':
+ initializers.serialize(self.recurrent_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'recurrent_regularizer':
+ regularizers.serialize(self.recurrent_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'recurrent_constraint':
+ constraints.serialize(self.recurrent_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
+ }
+ base_config = super(GRUCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
class GRU(RNN):
- # pylint: disable=line-too-long
"""Gated Recurrent Unit - Cho et al.
2014.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+ activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
+ the output of the layer (its "activation")..
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1455,12 +1471,7 @@ class GRU(RNN):
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
- References:
- - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](https://arxiv.org/abs/1409.1259)
- - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555v1)
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
"""
- # pylint: enable=line-too-long
def __init__(self,
units,
@@ -1518,8 +1529,8 @@ class GRU(RNN):
self.activity_regularizer = regularizers.get(activity_regularizer)
def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ self.cell._dropout_mask = None
+ self.cell._recurrent_dropout_mask = None
return super(GRU, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@@ -1589,28 +1600,40 @@ class GRU(RNN):
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout,
- 'implementation': self.implementation
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
}
base_config = super(GRU, self).get_config()
del base_config['cell']
@@ -1628,44 +1651,33 @@ class LSTMCell(Layer):
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+ activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Setting it to true will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et
al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1725,6 +1737,7 @@ class LSTMCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
+ @shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
@@ -1784,36 +1797,22 @@ class LSTMCell(Layer):
self.bias_o = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._dropout_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(4)
- ]
- else:
- self._dropout_mask = None
-
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
- if 0 < self.recurrent_dropout < 1:
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, self.units))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._recurrent_dropout_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(4)
- ]
- else:
- self._recurrent_dropout_mask = None
-
def call(self, inputs, states, training=None):
+ if 0 < self.dropout < 1 and self._dropout_mask is None:
+ self._dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs,
+ K.shape(inputs)[-1]),
+ self.dropout,
+ training=training,
+ count=4)
+ if (0 < self.recurrent_dropout < 1 and
+ self._recurrent_dropout_mask is None):
+ self._recurrent_dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs, self.units),
+ self.recurrent_dropout,
+ training=training,
+ count=4)
+
# dropout matrices for input units
dp_mask = self._dropout_mask
# dropout matrices for recurrent units
@@ -1887,54 +1886,81 @@ class LSTMCell(Layer):
h._uses_learning_phase = True
return h, [h, c]
+ def get_config(self):
+ config = {
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'recurrent_activation':
+ activations.serialize(self.recurrent_activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'recurrent_initializer':
+ initializers.serialize(self.recurrent_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'unit_forget_bias':
+ self.unit_forget_bias,
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'recurrent_regularizer':
+ regularizers.serialize(self.recurrent_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'recurrent_constraint':
+ constraints.serialize(self.recurrent_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
+ }
+ base_config = super(LSTMCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
class LSTM(RNN):
- # pylint: disable=line-too-long
"""Long-Short Term Memory layer - Hochreiter 1997.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+ activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
+ used for the linear transformation of the inputs..
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ used for the linear transformation of the recurrent state..
+ bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Setting it to true will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et
al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
+ the output of the layer (its "activation")..
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1964,13 +1990,7 @@ class LSTM(RNN):
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
- References:
- - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf)
- - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
- - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
"""
- # pylint: enable=line-too-long
def __init__(self,
units,
@@ -2030,8 +2050,8 @@ class LSTM(RNN):
self.activity_regularizer = regularizers.get(activity_regularizer)
def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ self.cell._dropout_mask = None
+ self.cell._recurrent_dropout_mask = None
return super(LSTM, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@@ -2105,29 +2125,42 @@ class LSTM(RNN):
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'unit_forget_bias': self.unit_forget_bias,
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'unit_forget_bias':
+ self.unit_forget_bias,
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout,
- 'implementation': self.implementation
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
}
base_config = super(LSTM, self).get_config()
del base_config['cell']
@@ -2140,6 +2173,23 @@ class LSTM(RNN):
return cls(**config)
+def _generate_dropout_ones(inputs, dims):
+ return K.ones((K.shape(inputs)[0], dims))
+
+
+def _generate_dropout_mask(ones, rate, training=None, count=1):
+
+ def dropped_inputs():
+ return K.dropout(ones, rate)
+
+ if count > 1:
+ return [
+ K.in_train_phase(dropped_inputs, ones, training=training)
+ for _ in range(count)
+ ]
+ return K.in_train_phase(dropped_inputs, ones, training=training)
+
+
class Recurrent(Layer):
"""Deprecated abstract base class for recurrent layers.
@@ -2266,6 +2316,7 @@ class Recurrent(Layer):
self.dropout = 0
self.recurrent_dropout = 0
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
index 7dc4c1db9b..a1407a24ea 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
@@ -392,6 +392,105 @@ class RNNTest(test.TestCase):
self.assertEqual(len(layer.trainable_weights), 3)
self.assertEqual(len(layer.non_trainable_weights), 0)
+ def test_state_reuse_with_dropout(self):
+ layer_class = keras.layers.SimpleRNN
+ embedding_dim = 4
+ units = 3
+ timesteps = 2
+ num_samples = 2
+
+ with self.test_session():
+ input1 = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
+ layer = layer_class(units,
+ return_state=True,
+ return_sequences=True,
+ dropout=0.2)
+ state = layer(input1)[1:]
+
+ input2 = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
+ output = layer_class(units)(input2, initial_state=state)
+ model = keras.Model([input1, input2], output)
+
+ inputs = [np.random.random((num_samples, timesteps, embedding_dim)),
+ np.random.random((num_samples, timesteps, embedding_dim))]
+ model.predict(inputs)
+
+ def test_builtin_rnn_cell_serialization(self):
+ for cell_class in [keras.layers.SimpleRNNCell,
+ keras.layers.GRUCell,
+ keras.layers.LSTMCell]:
+ with self.test_session():
+ # Test basic case.
+ x = keras.Input((None, 5))
+ cell = cell_class(32)
+ layer = keras.layers.RNN(cell)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+
+ # Test basic case serialization.
+ x_np = np.random.random((6, 5, 5))
+ y_np = model.predict(x_np)
+ weights = model.get_weights()
+ config = layer.get_config()
+ layer = keras.layers.RNN.from_config(config)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.set_weights(weights)
+ y_np_2 = model.predict(x_np)
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ # Test stacking.
+ cells = [cell_class(8),
+ cell_class(12),
+ cell_class(32)]
+ layer = keras.layers.RNN(cells)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+
+ # Test stacked RNN serialization.
+ x_np = np.random.random((6, 5, 5))
+ y_np = model.predict(x_np)
+ weights = model.get_weights()
+ config = layer.get_config()
+ layer = keras.layers.RNN.from_config(config)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.set_weights(weights)
+ y_np_2 = model.predict(x_np)
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ def test_stacked_rnn_dropout(self):
+ cells = [keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),
+ keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
+ layer = keras.layers.RNN(cells)
+
+ with self.test_session():
+ x = keras.Input((None, 5))
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile('sgd', 'mse')
+ x_np = np.random.random((6, 5, 5))
+ y_np = np.random.random((6, 3))
+ model.train_on_batch(x_np, y_np)
+
+ def test_stacked_rnn_compute_output_shape(self):
+ cells = [keras.layers.LSTMCell(3),
+ keras.layers.LSTMCell(6)]
+ embedding_dim = 4
+ timesteps = 2
+ layer = keras.layers.RNN(cells, return_state=True, return_sequences=True)
+ output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
+ expected_output_shape = [(None, timesteps, 6),
+ (None, 6),
+ (None, 6),
+ (None, 3),
+ (None, 3)]
+ self.assertEqual(
+ [tuple(o.as_list()) for o in output_shape],
+ expected_output_shape)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
index 452801b656..3667956f80 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.layers import utils as tf_layers_util
@@ -291,6 +292,7 @@ class Bidirectional(Wrapper):
self.backward_layer.initial_weights = weights[nw // 2:]
self.stateful = layer.stateful
self.return_sequences = layer.return_sequences
+ self.return_state = layer.return_state
self.supports_masking = True
def get_weights(self):
@@ -301,27 +303,54 @@ class Bidirectional(Wrapper):
self.forward_layer.set_weights(weights[:nw // 2])
self.backward_layer.set_weights(weights[nw // 2:])
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
- if self.merge_mode in ['sum', 'ave', 'mul']:
- return self.forward_layer.compute_output_shape(input_shape)
- elif self.merge_mode == 'concat':
- shape = self.forward_layer.compute_output_shape(input_shape).as_list()
- shape[-1] *= 2
- return tensor_shape.TensorShape(shape)
+ output_shape = tuple(self.forward_layer.compute_output_shape(
+ input_shape).as_list())
+ if self.return_state:
+ state_shape = output_shape[1:]
+ output_shape = output_shape[0]
+
+ if self.merge_mode == 'concat':
+ output_shape = list(output_shape)
+ output_shape[-1] *= 2
+ output_shape = tuple(output_shape)
elif self.merge_mode is None:
- shape = self.forward_layer.compute_output_shape(input_shape)
- return [shape, copy.copy(shape)]
+ output_shape = [output_shape, copy.copy(output_shape)]
- def call(self, inputs, training=None, mask=None):
+ if self.return_state:
+ if self.merge_mode is None:
+ return output_shape + state_shape + copy.copy(state_shape)
+ return [output_shape] + state_shape + copy.copy(state_shape)
+ return output_shape
+
+ def call(self, inputs, training=None, mask=None, initial_state=None):
kwargs = {}
if has_arg(self.layer.call, 'training'):
kwargs['training'] = training
if has_arg(self.layer.call, 'mask'):
kwargs['mask'] = mask
- y = self.forward_layer.call(inputs, **kwargs)
- y_rev = self.backward_layer.call(inputs, **kwargs)
+ if initial_state is not None and has_arg(self.layer.call, 'initial_state'):
+ if not isinstance(initial_state, list):
+ raise ValueError(
+ 'When passing `initial_state` to a Bidirectional RNN, the state '
+ 'should be a list containing the states of the underlying RNNs. '
+ 'Found: ' + str(initial_state))
+ forward_state = initial_state[:len(initial_state) // 2]
+ backward_state = initial_state[len(initial_state) // 2:]
+ y = self.forward_layer.call(inputs, initial_state=forward_state, **kwargs)
+ y_rev = self.backward_layer.call(
+ inputs, initial_state=backward_state, **kwargs)
+ else:
+ y = self.forward_layer.call(inputs, **kwargs)
+ y_rev = self.backward_layer.call(inputs, **kwargs)
+
+ if self.return_state:
+ states = y[1:] + y_rev[1:]
+ y = y[0]
+ y_rev = y_rev[0]
+
if self.return_sequences:
y_rev = K.reverse(y_rev, 1)
if self.merge_mode == 'concat':
@@ -343,6 +372,11 @@ class Bidirectional(Wrapper):
out._uses_learning_phase = True
else:
output._uses_learning_phase = True
+
+ if self.return_state:
+ if self.merge_mode is None:
+ return output + states
+ return [output] + states
return output
def reset_states(self):
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
index 0866c4b0ae..f48c8919a1 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
@@ -238,6 +238,131 @@ class BidirectionalTest(test.TestCase):
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, epochs=1, batch_size=1)
+ def test_Bidirectional_merged_value(self):
+ rnn = keras.layers.LSTM
+ samples = 2
+ dim = 5
+ timesteps = 3
+ units = 3
+ x = [np.random.rand(samples, timesteps, dim)]
+
+ with self.test_session():
+ for merge_mode in ['sum', 'mul', 'ave', 'concat', None]:
+ if merge_mode == 'sum':
+ merge_func = lambda y, y_rev: y + y_rev
+ elif merge_mode == 'mul':
+ merge_func = lambda y, y_rev: y * y_rev
+ elif merge_mode == 'ave':
+ merge_func = lambda y, y_rev: (y + y_rev) / 2
+ elif merge_mode == 'concat':
+ merge_func = lambda y, y_rev: np.concatenate((y, y_rev), axis=-1)
+ else:
+ merge_func = lambda y, y_rev: [y, y_rev]
+
+ # basic case
+ inputs = keras.Input((timesteps, dim))
+ layer = keras.layers.Bidirectional(
+ rnn(units, return_sequences=True), merge_mode=merge_mode)
+ f_merged = keras.backend.function([inputs], _to_list(layer(inputs)))
+ f_forward = keras.backend.function([inputs],
+ [layer.forward_layer.call(inputs)])
+ f_backward = keras.backend.function(
+ [inputs],
+ [keras.backend.reverse(layer.backward_layer.call(inputs), 1)])
+
+ y_merged = f_merged(x)
+ y_expected = _to_list(merge_func(f_forward(x)[0], f_backward(x)[0]))
+ assert len(y_merged) == len(y_expected)
+ for x1, x2 in zip(y_merged, y_expected):
+ self.assertAllClose(x1, x2, atol=1e-5)
+
+ # test return_state
+ inputs = keras.Input((timesteps, dim))
+ layer = keras.layers.Bidirectional(
+ rnn(units, return_state=True), merge_mode=merge_mode)
+ f_merged = keras.backend.function([inputs], layer(inputs))
+ f_forward = keras.backend.function([inputs],
+ layer.forward_layer.call(inputs))
+ f_backward = keras.backend.function([inputs],
+ layer.backward_layer.call(inputs))
+ n_states = len(layer.layer.states)
+
+ y_merged = f_merged(x)
+ y_forward = f_forward(x)
+ y_backward = f_backward(x)
+ y_expected = _to_list(merge_func(y_forward[0], y_backward[0]))
+ assert len(y_merged) == len(y_expected) + n_states * 2
+ for x1, x2 in zip(y_merged, y_expected):
+ self.assertAllClose(x1, x2, atol=1e-5)
+
+ y_merged = y_merged[-n_states * 2:]
+ y_forward = y_forward[-n_states:]
+ y_backward = y_backward[-n_states:]
+ for state_birnn, state_inner in zip(y_merged, y_forward + y_backward):
+ self.assertAllClose(state_birnn, state_inner, atol=1e-5)
+
+ def test_Bidirectional_dropout(self):
+ rnn = keras.layers.LSTM
+ samples = 2
+ dim = 5
+ timesteps = 3
+ units = 3
+ merge_mode = 'sum'
+ x = [np.random.rand(samples, timesteps, dim)]
+
+ with self.test_session():
+ inputs = keras.Input((timesteps, dim))
+ wrapped = keras.layers.Bidirectional(
+ rnn(units, dropout=0.2, recurrent_dropout=0.2), merge_mode=merge_mode)
+ outputs = _to_list(wrapped(inputs, training=True))
+ assert all(not getattr(x, '_uses_learning_phase') for x in outputs)
+
+ inputs = keras.Input((timesteps, dim))
+ wrapped = keras.layers.Bidirectional(
+ rnn(units, dropout=0.2, return_state=True), merge_mode=merge_mode)
+ outputs = _to_list(wrapped(inputs))
+ assert all(x._uses_learning_phase for x in outputs)
+
+ model = keras.Model(inputs, outputs)
+ assert model.uses_learning_phase
+ y1 = _to_list(model.predict(x))
+ y2 = _to_list(model.predict(x))
+ for x1, x2 in zip(y1, y2):
+ self.assertAllClose(x1, x2, atol=1e-5)
+
+ def test_Bidirectional_state_reuse(self):
+ rnn = keras.layers.LSTM
+ samples = 2
+ dim = 5
+ timesteps = 3
+ units = 3
+
+ with self.test_session():
+ inputs = keras.Input((timesteps, dim))
+ layer = keras.layers.Bidirectional(
+ rnn(units, return_state=True, return_sequences=True))
+ outputs = layer(inputs)
+ output, state = outputs[0], outputs[1:]
+
+ # test passing invalid initial_state: passing a tensor
+ with self.assertRaises(ValueError):
+ output = keras.layers.Bidirectional(
+ rnn(units))(output, initial_state=state[0])
+
+ # test valid usage: passing a list
+ output = keras.layers.Bidirectional(
+ rnn(units))(output, initial_state=state)
+ model = keras.Model(inputs, output)
+ inputs = np.random.rand(samples, timesteps, dim)
+ outputs = model.predict(inputs)
+
+
+def _to_list(ls):
+ if isinstance(ls, list):
+ return ls
+ else:
+ return [ls]
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py
index 1d6319abb1..fe0ef54360 100644
--- a/tensorflow/python/keras/_impl/keras/losses.py
+++ b/tensorflow/python/keras/_impl/keras/losses.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Built-in Keras loss functions.
+# pylint: disable=unused-import
+"""Built-in loss functions.
"""
from __future__ import absolute_import
from __future__ import division
@@ -34,7 +35,6 @@ def mean_absolute_error(y_true, y_pred):
def mean_absolute_percentage_error(y_true, y_pred):
- # Equivalent to MAE, but sometimes easier to interpret.
diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true), K.epsilon(), None))
return 100. * K.mean(diff, axis=-1)
@@ -56,10 +56,24 @@ def hinge(y_true, y_pred):
def categorical_hinge(y_true, y_pred):
pos = K.sum(y_true * y_pred, axis=-1)
neg = K.max((1. - y_true) * y_pred, axis=-1)
- return K.maximum(neg - pos + 1., 0.)
+ return K.maximum(0., neg - pos + 1.)
def logcosh(y_true, y_pred):
+ """Logarithm of the hyperbolic cosine of the prediction error.
+
+ `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
+ to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
+ like the mean squared error, but will not be so strongly affected by the
+ occasional wildly incorrect prediction.
+
+ Arguments:
+ y_true: tensor of true targets.
+ y_pred: tensor of predicted targets.
+
+ Returns:
+ Tensor with one scalar loss entry per sample.
+ """
def _logcosh(x):
return x + K.softplus(-2. * x) - K.log(2.)
diff --git a/tensorflow/python/keras/_impl/keras/metrics.py b/tensorflow/python/keras/_impl/keras/metrics.py
index 202048f26d..3c18e68260 100644
--- a/tensorflow/python/keras/_impl/keras/metrics.py
+++ b/tensorflow/python/keras/_impl/keras/metrics.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Built-in Keras metrics functions.
+# pylint: disable=unused-import
+"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
@@ -21,7 +22,6 @@ from __future__ import print_function
import six
from tensorflow.python.keras._impl.keras import backend as K
-# pylint: disable=unused-import
from tensorflow.python.keras._impl.keras.losses import binary_crossentropy
from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy
from tensorflow.python.keras._impl.keras.losses import cosine_proximity
@@ -35,7 +35,6 @@ from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_
from tensorflow.python.keras._impl.keras.losses import poisson
from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras._impl.keras.losses import squared_hinge
-# pylint: disable=unused-import
from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
@@ -60,8 +59,8 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
- return K.mean(K.in_top_k(y_pred,
- K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)
+ return K.mean(
+ K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)
# Aliases
diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py
index e262cc8c8e..9cd547200d 100644
--- a/tensorflow/python/keras/_impl/keras/models.py
+++ b/tensorflow/python/keras/_impl/keras/models.py
@@ -492,13 +492,13 @@ class Sequential(Model):
# to the input layer we just created.
layer(x)
- if len(layer.inbound_nodes[-1].output_tensors) != 1:
+ if len(layer._inbound_nodes[-1].output_tensors) != 1:
raise ValueError('All layers in a Sequential model '
'should have a single output tensor. '
'For multi-output layers, '
'use the functional API.')
- self.outputs = [layer.inbound_nodes[-1].output_tensors[0]]
+ self.outputs = [layer._inbound_nodes[-1].output_tensors[0]]
self.inputs = topology.get_source_inputs(self.outputs[0])
# We create an input node, which we will keep updated
diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py
index edfc0ce0eb..04017e4b28 100644
--- a/tensorflow/python/keras/_impl/keras/models_test.py
+++ b/tensorflow/python/keras/_impl/keras/models_test.py
@@ -340,6 +340,35 @@ class TestSequential(test.TestCase):
inner_model.trainable = True
self.assertEqual(len(model.trainable_weights), 4)
+ def test_sequential_update_disabling(self):
+ val_a = np.random.random((10, 4))
+ val_out = np.random.random((10, 4))
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.BatchNormalization(input_shape=(4,)))
+
+ model.trainable = False
+ assert not model.updates
+
+ model.compile('sgd', 'mse')
+ assert not model.updates
+ assert not model.model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
+ model.trainable = True
+ model.compile('sgd', 'mse')
+ assert model.updates
+ assert model.model.updates
+
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ assert np.abs(np.sum(x1 - x2)) > 1e-5
+
class TestModelCloning(test.TestCase):
diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
index a08073fa86..e47987aadc 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras optimizer classes (will eventually be replaced with core optimizers).
+# pylint: disable=invalid-name
+"""Built-in optimizer classes.
"""
from __future__ import absolute_import
from __future__ import division
@@ -121,9 +122,9 @@ class Optimizer(object):
param_values = K.batch_get_value(params)
for pv, p, w in zip(param_values, params, weights):
if pv.shape != w.shape:
- raise ValueError('Optimizer weight shape ' + str(pv.shape) +
- ' not compatible with '
- 'provided weight shape ' + str(w.shape))
+ raise ValueError(
+ 'Optimizer weight shape ' + str(pv.shape) + ' not compatible with '
+ 'provided weight shape ' + str(w.shape))
weight_value_tuples.append((p, w))
K.batch_set_value(weight_value_tuples)
@@ -156,7 +157,8 @@ class SGD(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
- momentum: float >= 0. Parameter updates momentum.
+ momentum: float >= 0. Parameter that accelerates SGD
+ in the relevant direction and dampens oscillations.
decay: float >= 0. Learning rate decay over each update.
nesterov: boolean. Whether to apply Nesterov momentum.
"""
@@ -177,9 +179,8 @@ class SGD(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
-
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
# momentum
shapes = [K.int_shape(p) for p in params]
moments = [K.zeros(shape) for shape in shapes]
@@ -224,32 +225,33 @@ class RMSprop(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
rho: float >= 0.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
+
"""
- def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, decay=0., **kwargs):
+ def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs):
super(RMSprop, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.lr = K.variable(lr, name='lr')
self.rho = K.variable(rho, name='rho')
self.decay = K.variable(decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
- accumulators = [
- K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params
- ]
+ accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
self.weights = accumulators
self.updates = [K.update_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
for p, g, a in zip(params, grads, accumulators):
# update accumulator
@@ -283,20 +285,19 @@ class Adagrad(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
- epsilon: float >= 0.
+ epsilon: float >= 0. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
- References:
- - [Adaptive Subgradient Methods for Online Learning and Stochastic
- Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
"""
- def __init__(self, lr=0.01, epsilon=1e-8, decay=0., **kwargs):
+ def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs):
super(Adagrad, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.lr = K.variable(lr, name='lr')
self.decay = K.variable(decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
@@ -309,8 +310,8 @@ class Adagrad(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
for p, g, a in zip(params, grads, accumulators):
new_a = a + K.square(g) # update accumulator
@@ -344,20 +345,19 @@ class Adadelta(Optimizer):
lr: float >= 0. Learning rate.
It is recommended to leave it at the default value.
rho: float >= 0.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
- References:
- - [Adadelta - an adaptive learning rate
- method](http://arxiv.org/abs/1212.5701)
"""
- def __init__(self, lr=1.0, rho=0.95, epsilon=1e-8, decay=0., **kwargs):
+ def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs):
super(Adadelta, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.lr = K.variable(lr, name='lr')
self.decay = K.variable(decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.rho = rho
self.epsilon = epsilon
self.initial_decay = decay
@@ -372,8 +372,8 @@ class Adadelta(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
# update accumulator
@@ -415,20 +415,21 @@ class Adam(Optimizer):
lr: float >= 0. Learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
+ amsgrad: boolean. Whether to apply the AMSGrad variant of this
+ algorithm from the paper "On the Convergence of Adam and
+ Beyond".
- References:
- - [Adam - A Method for Stochastic
- Optimization](http://arxiv.org/abs/1412.6980v8)
"""
def __init__(self,
lr=0.001,
beta_1=0.9,
beta_2=0.999,
- epsilon=1e-8,
+ epsilon=None,
decay=0.,
+ amsgrad=False,
**kwargs):
super(Adam, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
@@ -437,8 +438,11 @@ class Adam(Optimizer):
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
+ self.amsgrad = amsgrad
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
@@ -446,21 +450,30 @@ class Adam(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
t = K.cast(self.iterations, K.floatx()) + 1
- lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
- (1. - K.pow(self.beta_1, t)))
+ lr_t = lr * (
+ K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))
ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
- self.weights = [self.iterations] + ms + vs
+ if self.amsgrad:
+ vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+ else:
+ vhats = [K.zeros(1) for _ in params]
+ self.weights = [self.iterations] + ms + vs + vhats
- for p, g, m, v in zip(params, grads, ms, vs):
+ for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
- p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
+ if self.amsgrad:
+ vhat_t = K.maximum(vhat, v_t)
+ p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
+ self.updates.append(K.update(vhat, vhat_t))
+ else:
+ p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
self.updates.append(K.update(m, m_t))
self.updates.append(K.update(v, v_t))
@@ -479,7 +492,8 @@ class Adam(Optimizer):
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'decay': float(K.get_value(self.decay)),
- 'epsilon': self.epsilon
+ 'epsilon': self.epsilon,
+ 'amsgrad': self.amsgrad
}
base_config = super(Adam, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -494,19 +508,16 @@ class Adamax(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
- References:
- - [Adam - A Method for Stochastic
- Optimization](http://arxiv.org/abs/1412.6980v8)
"""
def __init__(self,
lr=0.002,
beta_1=0.9,
beta_2=0.999,
- epsilon=1e-8,
+ epsilon=None,
decay=0.,
**kwargs):
super(Adamax, self).__init__(**kwargs)
@@ -516,6 +527,8 @@ class Adamax(Optimizer):
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
@@ -525,8 +538,8 @@ class Adamax(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
t = K.cast(self.iterations, K.floatx()) + 1
lr_t = lr / (1. - K.pow(self.beta_1, t))
@@ -580,19 +593,15 @@ class Nadam(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
- References:
- - [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf)
- - [On the importance of initialization and momentum in deep
- learning](http://www.cs.toronto.edu/~fritz/absps/momentum.pdf)
"""
def __init__(self,
lr=0.002,
beta_1=0.9,
beta_2=0.999,
- epsilon=1e-8,
+ epsilon=None,
schedule_decay=0.004,
**kwargs):
super(Nadam, self).__init__(**kwargs)
@@ -602,12 +611,15 @@ class Nadam(Optimizer):
self.lr = K.variable(lr, name='lr')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.schedule_decay = schedule_decay
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
+
t = K.cast(self.iterations, K.floatx()) + 1
# Due to the recommendations in [2], i.e. warming momentum schedule
@@ -691,7 +703,6 @@ class TFOptimizer(Optimizer):
# Aliases.
-# pylint: disable=invalid-name
sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
@@ -700,8 +711,6 @@ adam = Adam
adamax = Adamax
nadam = Nadam
-# pylint: enable=invalid-name
-
def serialize(optimizer):
return serialize_keras_object(optimizer)
diff --git a/tensorflow/python/keras/_impl/keras/optimizers_test.py b/tensorflow/python/keras/_impl/keras/optimizers_test.py
index 6e9e4e6c99..57636afbf0 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers_test.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers_test.py
@@ -102,6 +102,7 @@ class KerasOptimizersTest(test.TestCase):
with self.test_session():
_test_optimizer(keras.optimizers.Adam())
_test_optimizer(keras.optimizers.Adam(decay=1e-3))
+ _test_optimizer(keras.optimizers.Adam(amsgrad=True))
def test_adamax(self):
with self.test_session():
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py
index 82441de592..db1fdd4e6b 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=g-import-not-at-top
"""Fairly basic set of tools for real-time data augmentation on image data.
Can easily be extended to include new transformations,
@@ -28,25 +29,22 @@ import re
import threading
import numpy as np
-from six.moves import range # pylint: disable=redefined-builtin
-
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
from tensorflow.python.platform import tf_logging as logging
-
-# pylint: disable=g-import-not-at-top
-try:
- from PIL import Image as pil_image
-except ImportError:
- pil_image = None
try:
from scipy import linalg
import scipy.ndimage as ndi
except ImportError:
linalg = None
ndi = None
-# pylint: enable=g-import-not-at-top
+
+
+try:
+ from PIL import Image as pil_image
+except ImportError:
+ pil_image = None
if pil_image is not None:
_PIL_INTERPOLATION_METHODS = {
@@ -88,7 +86,7 @@ def random_rotation(x,
Returns:
Rotated Numpy image tensor.
"""
- theta = np.pi / 180 * np.random.uniform(-rg, rg)
+ theta = np.deg2rad(np.random.uniform(-rg, rg))
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
[np.sin(theta), np.cos(theta), 0], [0, 0, 1]])
@@ -145,7 +143,7 @@ def random_shear(x,
Arguments:
x: Input tensor. Must be 3D.
- intensity: Transformation intensity.
+ intensity: Transformation intensity in degrees.
row_axis: Index of axis for rows in the input tensor.
col_axis: Index of axis for columns in the input tensor.
channel_axis: Index of axis for channels in the input tensor.
@@ -158,7 +156,7 @@ def random_shear(x,
Returns:
Sheared Numpy image tensor.
"""
- shear = np.random.uniform(-intensity, intensity)
+ shear = np.deg2rad(np.random.uniform(-intensity, intensity))
shear_matrix = np.array([[1, -np.sin(shear), 0], [0, np.cos(shear), 0],
[0, 0, 1]])
@@ -188,8 +186,10 @@ def random_zoom(x,
(one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
cval: Value used for points outside the boundaries
of the input if `mode='constant'`.
+
Returns:
Zoomed Numpy image tensor.
+
Raises:
ValueError: if `zoom_range` isn't a tuple.
"""
@@ -366,7 +366,7 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'):
grayscale: Boolean, whether to load the image as grayscale.
target_size: Either `None` (default to original size)
or tuple of ints `(img_height, img_width)`.
- interpolation: Interpolation method used to resample the image if the
+ interpolation: Interpolation method used to resample the image if the
target size is different from that of the loaded image.
Supported methods are "nearest", "bilinear", and "bicubic".
If PIL version 1.1.3 or newer is installed, "lanczos" is also
@@ -394,11 +394,9 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'):
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
- raise ValueError(
- 'Invalid interpolation method {} specified. Supported '
- 'methods are {}'.format(
- interpolation,
- ', '.join(_PIL_INTERPOLATION_METHODS.keys())))
+ raise ValueError('Invalid interpolation method {} specified. Supported '
+ 'methods are {}'.format(interpolation, ', '.join(
+ _PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]
img = img.resize(width_height_tuple, resample)
return img
@@ -407,7 +405,8 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'):
def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'):
return [
os.path.join(root, f)
- for root, _, files in os.walk(directory) for f in files
+ for root, _, files in os.walk(directory)
+ for f in files
if re.match(r'([\w]+\.(?:' + ext + '))', f)
]
@@ -423,9 +422,9 @@ class ImageDataGenerator(object):
zca_whitening: apply ZCA whitening.
zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
rotation_range: degrees (0 to 180).
- width_shift_range: fraction of total width.
- height_shift_range: fraction of total height.
- shear_range: shear intensity (shear angle in radians).
+ width_shift_range: fraction of total width, if < 1, or pixels if >= 1.
+ height_shift_range: fraction of total height, if < 1, or pixels if >= 1.
+ shear_range: shear intensity (shear angle in degrees).
zoom_range: amount of zoom. if scalar z, zoom will be randomly picked
in the range [1-z, 1+z]. A sequence of two can be passed instead
to select this range.
@@ -433,6 +432,12 @@ class ImageDataGenerator(object):
fill_mode: points outside the boundaries are filled according to the
given mode ('constant', 'nearest', 'reflect' or 'wrap'). Default
is 'nearest'.
+ Points outside the boundaries of the input are filled according to the
+ given mode:
+ 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
+ 'nearest': aaaaaaaa|abcd|dddddddd
+ 'reflect': abcddcba|abcd|dcbaabcd
+ 'wrap': abcdabcd|abcd|abcdabcd
cval: value used for points outside the boundaries when fill_mode is
'constant'. Default is 0.
horizontal_flip: whether to randomly flip images horizontally.
@@ -522,6 +527,32 @@ class ImageDataGenerator(object):
raise ValueError('`zoom_range` should be a float or '
'a tuple or list of two floats. '
'Received arg: ', zoom_range)
+ if zca_whitening:
+ if not featurewise_center:
+ self.featurewise_center = True
+ logging.warning('This ImageDataGenerator specifies '
+ '`zca_whitening`, which overrides '
+ 'setting of `featurewise_center`.')
+ if featurewise_std_normalization:
+ self.featurewise_std_normalization = False
+ logging.warning('This ImageDataGenerator specifies '
+ '`zca_whitening` '
+ 'which overrides setting of'
+ '`featurewise_std_normalization`.')
+ if featurewise_std_normalization:
+ if not featurewise_center:
+ self.featurewise_center = True
+ logging.warning('This ImageDataGenerator specifies '
+ '`featurewise_std_normalization`, '
+ 'which overrides setting of '
+ '`featurewise_center`.')
+ if samplewise_std_normalization:
+ if not samplewise_center:
+ self.samplewise_center = True
+ logging.warning('This ImageDataGenerator specifies '
+ '`samplewise_std_normalization`, '
+ 'which overrides setting of '
+ '`samplewise_center`.')
def flow(self,
x,
@@ -591,7 +622,7 @@ class ImageDataGenerator(object):
if self.samplewise_center:
x -= np.mean(x, keepdims=True)
if self.samplewise_std_normalization:
- x /= np.std(x, keepdims=True) + 1e-7
+ x /= (np.std(x, keepdims=True) + K.epsilon())
if self.featurewise_center:
if self.mean is not None:
@@ -603,7 +634,7 @@ class ImageDataGenerator(object):
'first by calling `.fit(numpy_data)`.')
if self.featurewise_std_normalization:
if self.std is not None:
- x /= (self.std + 1e-7)
+ x /= (self.std + K.epsilon())
else:
logging.warning('This ImageDataGenerator specifies '
'`featurewise_std_normalization`, but it hasn\'t '
@@ -636,7 +667,6 @@ class ImageDataGenerator(object):
"""
if ndi is None:
raise ImportError('Scipy is required for image transformations.')
-
# x is a single image, so it doesn't have image number at index 0
img_row_axis = self.row_axis - 1
img_col_axis = self.col_axis - 1
@@ -648,25 +678,27 @@ class ImageDataGenerator(object):
# use composition of homographies
# to generate final transform that needs to be applied
if self.rotation_range:
- theta = np.pi / 180 * np.random.uniform(-self.rotation_range,
- self.rotation_range)
+ theta = np.deg2rad(
+ np.random.uniform(-self.rotation_range, self.rotation_range))
else:
theta = 0
if self.height_shift_range:
- tx = np.random.uniform(-self.height_shift_range,
- self.height_shift_range) * x.shape[img_row_axis]
+ tx = np.random.uniform(-self.height_shift_range, self.height_shift_range)
+ if self.height_shift_range < 1:
+ tx *= x.shape[img_row_axis]
else:
tx = 0
if self.width_shift_range:
- ty = np.random.uniform(-self.width_shift_range,
- self.width_shift_range) * x.shape[img_col_axis]
+ ty = np.random.uniform(-self.width_shift_range, self.width_shift_range)
+ if self.width_shift_range < 1:
+ ty *= x.shape[img_col_axis]
else:
ty = 0
if self.shear_range:
- shear = np.random.uniform(-self.shear_range, self.shear_range)
+ shear = np.deg2rad(np.random.uniform(-self.shear_range, self.shear_range))
else:
shear = 0
@@ -744,7 +776,7 @@ class ImageDataGenerator(object):
if x.ndim != 4:
raise ValueError('Input to `.fit()` should have rank 4. '
'Got array with shape: ' + str(x.shape))
- if x.shape[self.channel_axis] not in {3, 4}:
+ if x.shape[self.channel_axis] not in {1, 3, 4}:
logging.warning(
'Expected input to be images (as Numpy array) '
'following the data format convention "' + self.data_format + '" '
@@ -784,10 +816,12 @@ class ImageDataGenerator(object):
raise ImportError('Scipy is required for zca_whitening.')
flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
- sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
- u, s, _ = linalg.svd(sigma)
- self.principal_components = np.dot(
- np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T)
+ num_examples = flat_x.shape[0]
+ _, s, vt = linalg.svd(flat_x / np.sqrt(num_examples))
+ s_expand = np.hstack(
+ (s, np.zeros(vt.shape[0] - num_examples, dtype=flat_x.dtype)))
+ self.principal_components = (
+ vt.T / np.sqrt(s_expand**2 + self.zca_epsilon)).dot(vt)
class Iterator(Sequence):
@@ -797,10 +831,10 @@ class Iterator(Sequence):
method.
Arguments:
- n: Integer, total number of samples in the dataset to loop over.
- batch_size: Integer, size of a batch.
- shuffle: Boolean, whether to shuffle the data between epochs.
- seed: Random seeding for data shuffling.
+ n: Integer, total number of samples in the dataset to loop over.
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ seed: Random seeding for data shuffling.
"""
def __init__(self, n, batch_size, shuffle, seed):
@@ -823,15 +857,14 @@ class Iterator(Sequence):
if idx >= len(self):
raise ValueError('Asked to retrieve element {idx}, '
'but the Sequence '
- 'has length {length}'.format(idx=idx,
- length=len(self)))
+ 'has length {length}'.format(idx=idx, length=len(self)))
if self.seed is not None:
np.random.seed(self.seed + self.total_batches_seen)
self.total_batches_seen += 1
if self.index_array is None:
self._set_index_array()
- index_array = self.index_array[self.batch_size * idx:self.batch_size *
- (idx + 1)]
+ index_array = self.index_array[self.batch_size * idx:self.batch_size * (
+ idx + 1)]
return self._get_batches_of_transformed_samples(index_array)
def __len__(self):
@@ -873,6 +906,7 @@ class Iterator(Sequence):
Arguments:
index_array: array of sample indices to include in batch.
+
Returns:
A batch of transformed samples.
"""
@@ -948,8 +982,8 @@ class NumpyArrayIterator(Iterator):
seed)
def _get_batches_of_transformed_samples(self, index_array):
- batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]),
- dtype=K.floatx())
+ batch_x = np.zeros(
+ tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=K.floatx())
for i, j in enumerate(index_array):
x = self.x[j]
x = self.image_data_generator.random_transform(x.astype(K.floatx()))
@@ -959,7 +993,9 @@ class NumpyArrayIterator(Iterator):
for i, j in enumerate(index_array):
img = array_to_img(batch_x[i], self.data_format, scale=True)
fname = '{prefix}_{index}_{hash}.{format}'.format(
- prefix=self.save_prefix, index=j, hash=np.random.randint(1e4),
+ prefix=self.save_prefix,
+ index=j,
+ hash=np.random.randint(1e4),
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
if self.y is None:
@@ -984,10 +1020,11 @@ class NumpyArrayIterator(Iterator):
def _count_valid_files_in_directory(directory, white_list_formats,
follow_links):
- """Count files with extension in `white_list_formats` in a directory.
+ """Count files with extension in `white_list_formats` contained in directory.
Arguments:
- directory: absolute path to the directory containing files to be counted
+ directory: absolute path to the directory
+ containing files to be counted
white_list_formats: set of strings containing allowed extensions for
the files to be counted.
follow_links: boolean.
@@ -1003,7 +1040,7 @@ def _count_valid_files_in_directory(directory, white_list_formats,
samples = 0
for _, _, files in _recursive_list(directory):
- for fname in sorted(files):
+ for fname in files:
is_valid = False
for extension in white_list_formats:
if fname.lower().endswith('.' + extension):
@@ -1043,7 +1080,7 @@ def _list_valid_filenames_in_directory(directory, white_list_formats,
subdir = os.path.basename(directory)
basedir = os.path.dirname(directory)
for root, _, files in _recursive_list(directory):
- for fname in files:
+ for fname in sorted(files):
is_valid = False
for extension in white_list_formats:
if fname.lower().endswith('.' + extension):
@@ -1167,8 +1204,8 @@ class DirectoryIterator(Iterator):
white_list_formats=white_list_formats,
follow_links=follow_links)
self.samples = sum(
- pool.map(function_partial, (os.path.join(directory, subdir)
- for subdir in classes)))
+ pool.map(function_partial,
+ (os.path.join(directory, subdir) for subdir in classes)))
print('Found %d images belonging to %d classes.' % (self.samples,
self.num_classes))
@@ -1181,8 +1218,9 @@ class DirectoryIterator(Iterator):
i = 0
for dirpath in (os.path.join(directory, subdir) for subdir in classes):
results.append(
- pool.apply_async(_list_valid_filenames_in_directory, (
- dirpath, white_list_formats, self.class_indices, follow_links)))
+ pool.apply_async(
+ _list_valid_filenames_in_directory,
+ (dirpath, white_list_formats, self.class_indices, follow_links)))
for res in results:
classes, filenames = res.get()
self.classes[i:i + len(classes)] = classes
@@ -1199,10 +1237,11 @@ class DirectoryIterator(Iterator):
# build batch of image data
for i, j in enumerate(index_array):
fname = self.filenames[j]
- img = load_img(os.path.join(self.directory, fname),
- grayscale=grayscale,
- target_size=self.target_size,
- interpolation=self.interpolation)
+ img = load_img(
+ os.path.join(self.directory, fname),
+ grayscale=grayscale,
+ target_size=self.target_size,
+ interpolation=self.interpolation)
x = img_to_array(img, data_format=self.data_format)
x = self.image_data_generator.random_transform(x)
x = self.image_data_generator.standardize(x)
@@ -1212,7 +1251,9 @@ class DirectoryIterator(Iterator):
for i, j in enumerate(index_array):
img = array_to_img(batch_x[i], self.data_format, scale=True)
fname = '{prefix}_{index}_{hash}.{format}'.format(
- prefix=self.save_prefix, index=j, hash=np.random.randint(1e7),
+ prefix=self.save_prefix,
+ index=j,
+ hash=np.random.randint(1e7),
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
# build batch of labels
@@ -1241,4 +1282,3 @@ class DirectoryIterator(Iterator):
# The transformation of images is not under thread lock
# so it can be done in parallel
return self._get_batches_of_transformed_samples(index_array)
-
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
index 642f4f2fac..4d59250af0 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Preprocessing utilities for sequence data.
+"""Utilities for preprocessing sequence data.
"""
from __future__ import absolute_import
from __future__ import division
@@ -129,7 +129,7 @@ def make_sampling_table(size, sampling_factor=1e-5):
is the probability that a word of rank i should be sampled.
"""
gamma = 0.577
- rank = np.array(list(range(size)))
+ rank = np.arange(size)
rank[0] = 1
inv_fq = rank * (np.log(rank) + gamma) + 0.5 - 1. / (12. * rank)
f = sampling_factor * inv_fq
@@ -170,7 +170,7 @@ def skipgrams(sequence,
if True labels will be categorical eg. [[1,0],[0,1],[0,1] .. ]
sampling_table: 1D array of size `vocabulary_size` where the entry i
encodes the probability to sample a word of rank i.
- seed: Random seed.
+ seed: random seed.
Returns:
couples, labels: where `couples` are int pairs and
@@ -224,3 +224,22 @@ def skipgrams(sequence,
random.shuffle(labels)
return couples, labels
+
+
+def _remove_long_seq(maxlen, seq, label):
+ """Removes sequences that exceed the maximum length.
+
+ Arguments:
+ maxlen: int, maximum length
+ seq: list of lists where each sublist is a sequence
+ label: list where each element is an integer
+
+ Returns:
+ new_seq, new_label: shortened lists for `seq` and `label`.
+ """
+ new_seq, new_label = [], []
+ for x, y in zip(seq, label):
+ if len(x) < maxlen:
+ new_seq.append(x)
+ new_label.append(y)
+ return new_seq, new_label
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text.py b/tensorflow/python/keras/_impl/keras/preprocessing/text.py
index 47e5aa064f..8f7f25dc0a 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/text.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/text.py
@@ -13,8 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Utilities for text input preprocessing.
-
-May benefit from a fast Cython rewrite.
"""
from __future__ import absolute_import
from __future__ import division
@@ -29,6 +27,9 @@ import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python.platform import tf_logging as logging
+
+
if sys.version_info < (3,):
maketrans = string.maketrans
else:
@@ -68,6 +69,21 @@ def one_hot(text,
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' '):
+ """One-hot encodes a text into a list of word indexes of size n.
+
+ This is a wrapper to the `hashing_trick` function using `hash` as the
+ hashing function; unicity of word to index mapping non-guaranteed.
+
+ Arguments:
+ text: Input text (string).
+ n: Dimension of the hashing space.
+ filters: Sequence of characters to filter out.
+ lower: Whether to convert the input to lowercase.
+ split: Sentence split marker (string).
+
+ Returns:
+ A list of integer word indices (unicity non-guaranteed).
+ """
return hashing_trick(
text, n, hash_function=hash, filters=filters, lower=lower, split=split)
@@ -99,6 +115,10 @@ def hashing_trick(text,
Two or more words may be assigned to the same index, due to possible
collisions by the hashing function.
+ The
+ probability
+ of a collision is in relation to the dimension of the hashing space and
+ the number of distinct objects.
"""
if hash_function is None:
hash_function = hash
@@ -127,6 +147,8 @@ class Tokenizer(object):
lower: boolean. Whether to convert the texts to lowercase.
split: character or string to use for token splitting.
char_level: if True, every character will be treated as a token.
+ oov_token: if given, it will be added to word_index and used to
+ replace out-of-vocabulary words during text_to_sequence calls
By default, all punctuation is removed, turning the texts into
space-separated sequences of words
@@ -141,7 +163,17 @@ class Tokenizer(object):
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' ',
- char_level=False):
+ char_level=False,
+ oov_token=None,
+ **kwargs):
+ # Legacy support
+ if 'nb_words' in kwargs:
+ logging.warning('The `nb_words` argument in `Tokenizer` '
+ 'has been renamed `num_words`.')
+ num_words = kwargs.pop('nb_words')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
self.word_counts = OrderedDict()
self.word_docs = {}
self.filters = filters
@@ -150,6 +182,7 @@ class Tokenizer(object):
self.num_words = num_words
self.document_count = 0
self.char_level = char_level
+ self.oov_token = oov_token
def fit_on_texts(self, texts):
"""Updates internal vocabulary based on a list of texts.
@@ -181,7 +214,13 @@ class Tokenizer(object):
sorted_voc = [wc[0] for wc in wcounts]
# note that index 0 is reserved, never assigned to an existing word
self.word_index = dict(
- list(zip(sorted_voc, list(range(1, len(sorted_voc) + 1)))))
+ list(zip(sorted_voc, list(range(1,
+ len(sorted_voc) + 1)))))
+
+ if self.oov_token is not None:
+ i = self.word_index.get(self.oov_token)
+ if i is None:
+ self.word_index[self.oov_token] = len(self.word_index) + 1
self.index_docs = {}
for w, c in list(self.word_docs.items()):
@@ -248,6 +287,10 @@ class Tokenizer(object):
continue
else:
vect.append(i)
+ elif self.oov_token is not None:
+ i = self.word_index.get(self.oov_token)
+ if i is not None:
+ vect.append(i)
yield vect
def texts_to_matrix(self, texts, mode='binary'):
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
index 17ab48ba3f..a934e331c4 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
@@ -76,6 +76,22 @@ class TestText(test.TestCase):
self.assertLessEqual(np.max(encoded), 4)
self.assertGreaterEqual(np.min(encoded), 1)
+ def test_tokenizer_oov_flag(self):
+ x_train = ['This text has only known words']
+ x_test = ['This text has some unknown words'] # 2 OOVs: some, unknown
+
+ # Defalut, without OOV flag
+ tokenizer = keras.preprocessing.text.Tokenizer()
+ tokenizer.fit_on_texts(x_train)
+ x_test_seq = tokenizer.texts_to_sequences(x_test)
+ assert len(x_test_seq[0]) == 4 # discards 2 OOVs
+
+ # With OOV feature
+ tokenizer = keras.preprocessing.text.Tokenizer(oov_token='<unk>')
+ tokenizer.fit_on_texts(x_train)
+ x_test_seq = tokenizer.texts_to_sequences(x_test)
+ assert len(x_test_seq[0]) == 6 # OOVs marked in place
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/regularizers.py b/tensorflow/python/keras/_impl/keras/regularizers.py
index 161ff9bf5b..c53ee8a1ae 100644
--- a/tensorflow/python/keras/_impl/keras/regularizers.py
+++ b/tensorflow/python/keras/_impl/keras/regularizers.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras built-in regularizers.
+"""Built-in regularizers.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py
index d0be29f829..fcee9fbcc3 100644
--- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import abstractmethod
+from contextlib import closing
import hashlib
import multiprocessing
from multiprocessing.pool import ThreadPool
@@ -39,10 +41,11 @@ from six.moves.urllib.request import urlopen
from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
+
try:
- import queue # pylint:disable=g-import-not-at-top
+ import queue
except ImportError:
- import Queue as queue # pylint:disable=g-import-not-at-top
+ import Queue as queue
if sys.version_info[0] == 2:
@@ -86,7 +89,7 @@ if sys.version_info[0] == 2:
for chunk in chunk_read(response, reporthook=reporthook):
fd.write(chunk)
else:
- from six.moves.urllib.request import urlretrieve # pylint: disable=g-import-not-at-top
+ from six.moves.urllib.request import urlretrieve
def _extract_archive(file_path, path='.', archive_format='auto'):
@@ -186,7 +189,7 @@ def get_file(fname,
Path to the downloaded file
"""
if cache_dir is None:
- cache_dir = os.path.expanduser(os.path.join('~', '.keras'))
+ cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
if md5_hash is not None and file_hash is None:
file_hash = md5_hash
hash_algorithm = 'md5'
@@ -320,31 +323,41 @@ class Sequence(object):
Every `Sequence` must implements the `__getitem__` and the `__len__` methods.
If you want to modify your dataset between epochs you may implement
- `on_epoch_end`. The method `__getitem__` should return a complete batch.
+ `on_epoch_end`.
+ The method `__getitem__` should return a complete batch.
+
+ # Notes
- Notes:
`Sequence` are a safer way to do multiprocessing. This structure guarantees
- that the network will only train once on each sample per epoch which is not
- the case with generators.
+ that the network will only train once
+ on each sample per epoch which is not the case with generators.
+
Examples:
+
```python
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
+
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
+
class CIFAR10Sequence(Sequence):
+
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
+
def __len__(self):
return math.ceil(len(self.x) / self.batch_size)
+
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) *
- self.batch_size]
+ self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) *
- self.batch_size]
+ self.batch_size]
+
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
@@ -372,7 +385,6 @@ class Sequence(object):
"""
raise NotImplementedError
- @abstractmethod
def on_epoch_end(self):
"""Method called at the end of every epoch.
"""
@@ -470,35 +482,36 @@ class OrderedEnqueuer(SequenceEnqueuer):
Arguments:
sequence: A `keras.utils.data_utils.Sequence` object.
- use_multiprocessing: Use multiprocessing if True, otherwise threading
- shuffle: Whether to shuffle the data at the beginning of each epoch
+ use_multiprocessing: use multiprocessing if True, otherwise threading
+ shuffle: whether to shuffle the data at the beginning of each epoch
"""
def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
self.sequence = sequence
self.use_multiprocessing = use_multiprocessing
- # Doing Multiprocessing.Value += x is not process-safe.
global _SEQUENCE_COUNTER
if _SEQUENCE_COUNTER is None:
- if self.use_multiprocessing:
+ try:
_SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
- else:
+ except OSError:
+ # In this case the OS does not allow us to use
+ # multiprocessing. We resort to an int
+ # for enqueuer indexing.
_SEQUENCE_COUNTER = 0
- if self.use_multiprocessing:
+ if isinstance(_SEQUENCE_COUNTER, int):
+ self.uid = _SEQUENCE_COUNTER
+ _SEQUENCE_COUNTER += 1
+ else:
+ # Doing Multiprocessing.Value += x is not process-safe.
with _SEQUENCE_COUNTER.get_lock():
self.uid = _SEQUENCE_COUNTER.value
_SEQUENCE_COUNTER.value += 1
- else:
- self.uid = _SEQUENCE_COUNTER
- if isinstance(_SEQUENCE_COUNTER, int):
- _SEQUENCE_COUNTER += 1
- else:
- _SEQUENCE_COUNTER.value += 1
+
self.shuffle = shuffle
self.workers = 0
- self.executor = None
+ self.executor_fn = None
self.queue = None
self.run_thread = None
self.stop_signal = None
@@ -515,9 +528,9 @@ class OrderedEnqueuer(SequenceEnqueuer):
(when full, workers could block on `put()`)
"""
if self.use_multiprocessing:
- self.executor = multiprocessing.Pool(workers)
+ self.executor_fn = lambda: multiprocessing.Pool(workers)
else:
- self.executor = ThreadPool(workers)
+ self.executor_fn = lambda: ThreadPool(workers)
self.workers = workers
self.queue = queue.Queue(max_queue_size)
self.stop_signal = threading.Event()
@@ -533,24 +546,26 @@ class OrderedEnqueuer(SequenceEnqueuer):
return
def _run(self):
- """Function to submit request to the executor & queue `Future` objects."""
+ """Submits request to the executor and queue the `Future` objects."""
sequence = list(range(len(self.sequence)))
self._send_sequence() # Share the initial sequence
while True:
if self.shuffle:
random.shuffle(sequence)
- for i in sequence:
- if self.stop_signal.is_set():
- return
- self.queue.put(
- self.executor.apply_async(get_index, (self.uid, i)), block=True)
- # Done with the current epoch, waiting for the final batches
- self._wait_queue()
+ with closing(self.executor_fn()) as executor:
+ for i in sequence:
+ if self.stop_signal.is_set():
+ return
+ self.queue.put(
+ executor.apply_async(get_index, (self.uid, i)), block=True)
- if self.stop_signal.is_set():
- # We're done
- return
+ # Done with the current epoch, waiting for the final batches
+ self._wait_queue()
+
+ if self.stop_signal.is_set():
+ # We're done
+ return
# Call the internal on epoch end.
self.sequence.on_epoch_end()
@@ -562,8 +577,9 @@ class OrderedEnqueuer(SequenceEnqueuer):
Skip the data if it is `None`.
Yields:
- Tuples (inputs, targets)
- or (inputs, targets, sample_weights)
+ The next element in the queue, i.e. a tuple
+ `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
"""
try:
while self.is_running():
@@ -577,14 +593,8 @@ class OrderedEnqueuer(SequenceEnqueuer):
def _send_sequence(self):
"""Send current Sequence to all workers."""
- _SHARED_SEQUENCES[
- self.uid] = self.sequence # For new processes that may spawn
-
- self._close_pool()
- if self.use_multiprocessing:
- self.executor = multiprocessing.Pool(self.workers)
- else:
- self.executor = ThreadPool(self.workers)
+ # For new processes that may spawn
+ _SHARED_SEQUENCES[self.uid] = self.sequence
def stop(self, timeout=None):
"""Stops running threads and wait for them to exit, if necessary.
@@ -599,14 +609,9 @@ class OrderedEnqueuer(SequenceEnqueuer):
self.queue.queue.clear()
self.queue.unfinished_tasks = 0
self.queue.not_full.notify()
- self._close_pool()
self.run_thread.join(timeout)
_SHARED_SEQUENCES[self.uid] = None
- def _close_pool(self):
- self.executor.close()
- self.executor.join()
-
class GeneratorEnqueuer(SequenceEnqueuer):
"""Builds a queue out of a data generator.
@@ -631,26 +636,53 @@ class GeneratorEnqueuer(SequenceEnqueuer):
seed=None):
self.wait_time = wait_time
self._generator = generator
- self._use_multiprocessing = use_multiprocessing
+ if os.name is 'nt' and use_multiprocessing is True:
+ # On Windows, avoid **SYSTEMATIC** error in `multiprocessing`:
+ # `TypeError: can't pickle generator objects`
+ # => Suggest multithreading instead of multiprocessing on Windows
+ raise ValueError('Using a generator with `use_multiprocessing=True`'
+ ' is not supported on Windows (no marshalling of'
+ ' generators across process boundaries). Instead,'
+ ' use single thread/process or multithreading.')
+ else:
+ self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self._manager = None
self.queue = None
self.seed = seed
- def start(self, workers=1, max_queue_size=10):
- """Kicks off threads which add data from the generator into the queue.
-
- Arguments:
- workers: number of worker threads
- max_queue_size: queue size
- (when full, threads could block on `put()`)
- """
-
- def data_generator_task():
+ def _data_generator_task(self):
+ if self._use_multiprocessing is False:
+ while not self._stop_event.is_set():
+ with self.genlock:
+ try:
+ if (self.queue is not None and
+ self.queue.qsize() < self.max_queue_size):
+ # On all OSes, avoid **SYSTEMATIC** error
+ # in multithreading mode:
+ # `ValueError: generator already executing`
+ # => Serialize calls to
+ # infinite iterator/generator's next() function
+ generator_output = next(self._generator)
+ self.queue.put((True, generator_output))
+ else:
+ time.sleep(self.wait_time)
+ except StopIteration:
+ break
+ except Exception as e: # pylint: disable=broad-except
+ # Can't pickle tracebacks.
+ # As a compromise, print the traceback and pickle None instead.
+ if not hasattr(e, '__traceback__'):
+ setattr(e, '__traceback__', sys.exc_info()[2])
+ self.queue.put((False, e))
+ self._stop_event.set()
+ break
+ else:
while not self._stop_event.is_set():
try:
- if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
+ if (self.queue is not None and
+ self.queue.qsize() < self.max_queue_size):
generator_output = next(self._generator)
self.queue.put((True, generator_output))
else:
@@ -658,24 +690,34 @@ class GeneratorEnqueuer(SequenceEnqueuer):
except StopIteration:
break
except Exception as e: # pylint: disable=broad-except
- # Can't pick tracebacks.
+ # Can't pickle tracebacks.
# As a compromise, print the traceback and pickle None instead.
- if self._use_multiprocessing:
- traceback.print_exc()
- setattr(e, '__traceback__', None)
- elif not hasattr(e, '__traceback__'):
- setattr(e, '__traceback__', sys.exc_info()[2])
+ traceback.print_exc()
+ setattr(e, '__traceback__', None)
self.queue.put((False, e))
self._stop_event.set()
break
+ def start(self, workers=1, max_queue_size=10):
+ """Kicks off threads which add data from the generator into the queue.
+
+ Arguments:
+ workers: number of worker threads
+ max_queue_size: queue size
+ (when full, threads could block on `put()`)
+ """
try:
+ self.max_queue_size = max_queue_size
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
- self.queue = queue.Queue()
+ # On all OSes, avoid **SYSTEMATIC** error in multithreading mode:
+ # `ValueError: generator already executing`
+ # => Serialize calls to infinite iterator/generator's next() function
+ self.genlock = threading.Lock()
+ self.queue = queue.Queue(maxsize=max_queue_size)
self._stop_event = threading.Event()
for _ in range(workers):
@@ -683,12 +725,12 @@ class GeneratorEnqueuer(SequenceEnqueuer):
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.seed)
- thread = multiprocessing.Process(target=data_generator_task)
+ thread = multiprocessing.Process(target=self._data_generator_task)
thread.daemon = True
if self.seed is not None:
self.seed += 1
else:
- thread = threading.Thread(target=data_generator_task)
+ thread = threading.Thread(target=self._data_generator_task)
self._threads.append(thread)
thread.start()
except:
@@ -710,11 +752,15 @@ class GeneratorEnqueuer(SequenceEnqueuer):
self._stop_event.set()
for thread in self._threads:
- if thread.is_alive():
- if self._use_multiprocessing:
+ if self._use_multiprocessing:
+ if thread.is_alive():
thread.terminate()
- else:
- thread.join(timeout)
+ else:
+ # The thread.is_alive() test is subject to a race condition:
+ # the thread could terminate right after the test and before the
+ # join, rendering this test meaningless -> Call thread.join()
+ # always, which is ok no matter what the status of the thread.
+ thread.join(timeout)
if self._manager:
self._manager.shutdown()
@@ -729,7 +775,9 @@ class GeneratorEnqueuer(SequenceEnqueuer):
Skip the data if it is `None`.
Yields:
- Data arrays.
+ The next element in the queue, i.e. a tuple
+ `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
"""
while self.is_running():
if not self.queue.empty():
@@ -747,9 +795,8 @@ class GeneratorEnqueuer(SequenceEnqueuer):
else:
time.sleep(self.wait_time)
- # Make sure to rethrow the first exception in the queue, if any
+ # Make sure to rethrow the first exception in the queue, if any
while not self.queue.empty():
success, value = self.queue.get()
if not success:
six.reraise(value.__class__, value, value.__traceback__)
-
diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
index e9e54c2a2a..adbe6c3288 100644
--- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import binascii
import codecs
import marshal
import os
@@ -29,10 +30,12 @@ import six
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.tf_export import tf_export
_GLOBAL_CUSTOM_OBJECTS = {}
+@tf_export('keras.utils.CustomObjectScope')
class CustomObjectScope(object):
"""Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
@@ -68,6 +71,7 @@ class CustomObjectScope(object):
_GLOBAL_CUSTOM_OBJECTS.update(self.backup)
+@tf_export('keras.utils.custom_object_scope')
def custom_object_scope(*args):
"""Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
@@ -98,6 +102,7 @@ def custom_object_scope(*args):
return CustomObjectScope(*args)
+@tf_export('keras.utils.get_custom_objects')
def get_custom_objects():
"""Retrieves a live reference to the global dictionary of custom objects.
@@ -118,6 +123,7 @@ def get_custom_objects():
return _GLOBAL_CUSTOM_OBJECTS
+@tf_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
_, instance = tf_decorator.unwrap(instance)
if instance is None:
@@ -133,6 +139,7 @@ def serialize_keras_object(instance):
raise ValueError('Cannot serialize', instance)
+@tf_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
module_objects=None,
custom_objects=None,
@@ -249,7 +256,10 @@ def func_load(code, defaults=None, closure=None, globs=None):
if closure is not None:
closure = tuple(ensure_value_to_cell(_) for _ in closure)
- raw_code = codecs.decode(code.encode('ascii'), 'base64')
+ try:
+ raw_code = codecs.decode(code.encode('ascii'), 'base64')
+ except (UnicodeEncodeError, binascii.Error):
+ raw_code = code.encode('raw_unicode_escape')
code = marshal.loads(raw_code)
if globs is None:
globs = globals()
@@ -275,6 +285,7 @@ def has_arg(fn, name, accept_all=False):
return name in arg_spec.args
+@tf_export('keras.utils.Progbar')
class Progbar(object):
"""Displays a progress bar.
diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py
index a8fc18c17a..b36c769843 100644
--- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=g-import-not-at-top
"""Utilities related to disk I/O."""
from __future__ import absolute_import
from __future__ import division
@@ -24,7 +25,7 @@ import numpy as np
try:
- import h5py # pylint:disable=g-import-not-at-top
+ import h5py
except ImportError:
h5py = None
@@ -63,11 +64,11 @@ class HDF5Matrix(object):
'HDF5 and h5py installed.')
if datapath not in list(self.refs.keys()):
- self._f = h5py.File(datapath)
- self.refs[datapath] = self._f
+ f = h5py.File(datapath)
+ self.refs[datapath] = f
else:
- self._f = self.refs[datapath]
- self.data = self._f[dataset]
+ f = self.refs[datapath]
+ self.data = f[dataset]
self.start = start
if end is None:
self.end = self.data.shape[0]
@@ -78,9 +79,6 @@ class HDF5Matrix(object):
def __len__(self):
return self.end - self.start
- def __del__(self):
- self._f.close()
-
def __getitem__(self, key):
if isinstance(key, slice):
start, stop = key.start, key.stop
diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
index 053c0600a3..a2d32424b5 100644
--- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Utilities related to Keras layers.
+# pylint: disable=protected-access
+"""Utilities related to layer/model functionality.
"""
from __future__ import absolute_import
from __future__ import division
@@ -28,10 +29,10 @@ def count_params(weights):
"""Count the total number of scalars composing the weights.
Arguments:
- weights: An iterable containing the weights on which to compute params
+ weights: An iterable containing the weights on which to compute params
Returns:
- The total number of scalars composing the weights
+ The total number of scalars composing the weights
"""
return int(np.sum([K.count_params(p) for p in set(weights)]))
@@ -46,10 +47,11 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
terminal window sizes).
positions: Relative or absolute positions of log elements in each line.
If not provided, defaults to `[.33, .55, .67, 1.]`.
- print_fn: Print function to use (defaults to `print`).
+ print_fn: Print function to use.
It will be called on each line of the summary.
You can set it to a custom function
in order to capture the string summary.
+ It defaults to `print` (prints to stdout).
"""
if print_fn is None:
print_fn = print
@@ -58,12 +60,13 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
sequential_like = True
else:
sequential_like = True
- nodes_by_depth = model._nodes_by_depth.values() # pylint: disable=protected-access
+ nodes_by_depth = model._nodes_by_depth.values()
nodes = []
for v in nodes_by_depth:
if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1):
- # If the model has multiple nodes or if the nodes have
- # multiple inbound_layers, the model is no longer sequential.
+ # if the model has multiple nodes
+ # or if the nodes have multiple inbound_layers
+ # the model is no longer sequential
sequential_like = False
break
nodes += v
@@ -71,7 +74,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
# search for shared layers
for layer in model.layers:
flag = False
- for node in layer.inbound_nodes:
+ for node in layer._inbound_nodes:
if node in nodes:
if flag:
sequential_like = False
@@ -96,7 +99,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
# header names for the different log elements
to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
relevant_nodes = []
- for v in model._nodes_by_depth.values(): # pylint: disable=protected-access
+ for v in model._nodes_by_depth.values():
relevant_nodes += v
def print_row(fields, positions):
@@ -134,7 +137,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
except AttributeError:
output_shape = 'multiple'
connections = []
- for node in layer._inbound_nodes: # pylint: disable=protected-access
+ for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
@@ -142,8 +145,8 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
inbound_layer = node.inbound_layers[i].name
inbound_node_index = node.node_indices[i]
inbound_tensor_index = node.tensor_indices[i]
- connections.append(inbound_layer + '[' + str(inbound_node_index) + ']['
- + str(inbound_tensor_index) + ']')
+ connections.append(inbound_layer + '[' + str(inbound_node_index) +
+ '][' + str(inbound_tensor_index) + ']')
name = layer.name
cls_name = layer.__class__.__name__
@@ -172,9 +175,9 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
else:
print_fn('_' * line_length)
- model._check_trainable_weights_consistency() # pylint: disable=protected-access
+ model._check_trainable_weights_consistency()
if hasattr(model, '_collected_trainable_weights'):
- trainable_count = count_params(model._collected_trainable_weights) # pylint: disable=protected-access
+ trainable_count = count_params(model._collected_trainable_weights)
else:
trainable_count = count_params(model.trainable_weights)
diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils.py b/tensorflow/python/keras/_impl/keras/utils/np_utils.py
index 67d83bf42c..231833e776 100644
--- a/tensorflow/python/keras/_impl/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/np_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/python/keras/_impl/keras/utils/training_utils.py b/tensorflow/python/keras/_impl/keras/utils/training_utils.py
index 0bf4ac8a24..ce7402e9d2 100644
--- a/tensorflow/python/keras/_impl/keras/utils/training_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/training_utils.py
@@ -21,6 +21,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine.training import Model
from tensorflow.python.ops import array_ops
+from tensorflow.python.util.tf_export import tf_export
def _get_available_devices():
@@ -32,6 +33,7 @@ def _normalize_device_name(name):
return name
+@tf_export('keras.utils.multi_gpu_model')
def multi_gpu_model(model, gpus):
"""Replicates a model on different GPUs.
@@ -203,4 +205,3 @@ def multi_gpu_model(model, gpus):
for name, outputs in zip(model.output_names, all_outputs):
merged.append(concatenate(outputs, axis=0, name=name))
return Model(model.inputs, merged)
-
diff --git a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
index d56c4484ce..0c5f2c19c7 100644
--- a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,30 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=protected-access
+# pylint: disable=g-import-not-at-top
"""Utilities related to model visualization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
-import sys
+
try:
# pydot-ng is a fork of pydot that is better maintained.
- import pydot_ng as pydot # pylint: disable=g-import-not-at-top
+ import pydot_ng as pydot
except ImportError:
- # Fall back on pydot if necessary.
- # Silence a `print` statement that occurs in case of import error,
- # by temporarily replacing sys.stdout.
- _stdout = sys.stdout
- sys.stdout = sys.stderr
+ # pydotplus is an improved version of pydot
try:
- import pydot # pylint: disable=g-import-not-at-top
+ import pydotplus as pydot
except ImportError:
- pydot = None
- finally:
- # Restore sys.stdout.
- sys.stdout = _stdout
+ # Fall back on pydot if necessary.
+ try:
+ import pydot
+ except ImportError:
+ pydot = None
def _check_pydot():
@@ -65,8 +64,8 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
Returns:
A `pydot.Dot` instance representing the Keras model.
"""
- from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper # pylint: disable=g-import-not-at-top
- from tensorflow.python.keras._impl.keras.models import Sequential # pylint: disable=g-import-not-at-top
+ from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper
+ from tensorflow.python.keras._impl.keras.models import Sequential
_check_pydot()
dot = pydot.Dot()
@@ -118,9 +117,9 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
# Connect nodes with edges.
for layer in layers:
layer_id = str(id(layer))
- for i, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access
+ for i, node in enumerate(layer._inbound_nodes):
node_key = layer.name + '_ib-' + str(i)
- if node_key in model._network_nodes: # pylint: disable=protected-access
+ if node_key in model._container_nodes:
for inbound_layer in node.inbound_layers:
inbound_layer_id = str(id(inbound_layer))
layer_id = str(id(layer))
diff --git a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
index bc788d874f..223ceac3de 100644
--- a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
+++ b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""API wrapper allowing to use certain Keras models with the Scikit-Learn API.
+"""Wrapper for using the Scikit-Learn API with Keras models.
"""
from __future__ import absolute_import
from __future__ import division
@@ -24,8 +24,8 @@ import types
import numpy as np
from tensorflow.python.keras._impl.keras.models import Sequential
+from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical
-from tensorflow.python.util import tf_inspect
class BaseWrapper(object):
@@ -75,7 +75,7 @@ class BaseWrapper(object):
self.check_params(sk_params)
def check_params(self, params):
- """Checks for user typos in "params".
+ """Checks for user typos in `params`.
Arguments:
params: dictionary; the parameters to be checked
@@ -95,13 +95,11 @@ class BaseWrapper(object):
else:
legal_params_fns.append(self.build_fn)
- legal_params = []
- for fn in legal_params_fns:
- legal_params += tf_inspect.getargspec(fn)[0]
- legal_params = set(legal_params)
-
for params_name in params:
- if params_name not in legal_params:
+ for fn in legal_params_fns:
+ if has_arg(fn, params_name):
+ break
+ else:
if params_name != 'nb_epoch':
raise ValueError('{} is not a legal parameter'.format(params_name))
@@ -136,10 +134,10 @@ class BaseWrapper(object):
Arguments:
x : array-like, shape `(n_samples, n_features)`
- Training samples where n_samples in the number of samples
- and n_features is the number of features.
+ Training samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
- True labels for X.
+ True labels for `x`.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.fit`
@@ -170,21 +168,20 @@ class BaseWrapper(object):
return history
def filter_sk_params(self, fn, override=None):
- """Filters `sk_params` and return those in `fn`'s arguments.
+ """Filters `sk_params` and returns those in `fn`'s arguments.
Arguments:
fn : arbitrary function
- override: dictionary, values to override sk_params
+ override: dictionary, values to override `sk_params`
Returns:
- res : dictionary dictionary containing variables
- in both sk_params and fn's arguments.
+ res : dictionary containing variables
+ in both `sk_params` and `fn`'s arguments.
"""
override = override or {}
res = {}
- fn_args = tf_inspect.getargspec(fn)[0]
for name, value in self.sk_params.items():
- if name in fn_args:
+ if has_arg(fn, name):
res.update({name: value})
res.update(override)
return res
@@ -199,10 +196,10 @@ class KerasClassifier(BaseWrapper):
Arguments:
x : array-like, shape `(n_samples, n_features)`
- Training samples where n_samples in the number of samples
- and n_features is the number of features.
+ Training samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
- True labels for X.
+ True labels for `x`.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.fit`
@@ -229,8 +226,8 @@ class KerasClassifier(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
**kwargs: dictionary arguments
Legal arguments are the arguments
of `Sequential.predict_classes`.
@@ -248,8 +245,8 @@ class KerasClassifier(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
**kwargs: dictionary arguments
Legal arguments are the arguments
of `Sequential.predict_classes`.
@@ -258,8 +255,8 @@ class KerasClassifier(BaseWrapper):
proba: array-like, shape `(n_samples, n_outputs)`
Class probability estimates.
In the case of binary classification,
- tp match the scikit-learn API,
- will return an array of shape '(n_samples, 2)'
+ to match the scikit-learn API,
+ will return an array of shape `(n_samples, 2)`
(instead of `(n_sample, 1)` as in Keras).
"""
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
@@ -276,16 +273,16 @@ class KerasClassifier(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
- True labels for x.
+ True labels for `x`.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.evaluate`.
Returns:
score: float
- Mean accuracy of predictions on X wrt. y.
+ Mean accuracy of predictions on `x` wrt. `y`.
Raises:
ValueError: If the underlying model isn't configured to
@@ -321,8 +318,8 @@ class KerasRegressor(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.predict`.
@@ -338,16 +335,16 @@ class KerasRegressor(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
y: array-like, shape `(n_samples,)`
- True labels for X.
+ True labels for `x`.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.evaluate`.
Returns:
score: float
- Mean accuracy of predictions on X wrt. y.
+ Mean accuracy of predictions on `x` wrt. `y`.
"""
kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
loss = self.model.evaluate(x, y, **kwargs)
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index 34f1435ffb..fccedf919a 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -18,16 +18,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.keras.applications import densenet
from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
+from tensorflow.python.keras.applications import nasnet
from tensorflow.python.keras.applications import resnet50
from tensorflow.python.keras.applications import vgg16
from tensorflow.python.keras.applications import vgg19
from tensorflow.python.keras.applications import xception
+from tensorflow.python.keras.applications.densenet import DenseNet121
+from tensorflow.python.keras.applications.densenet import DenseNet169
+from tensorflow.python.keras.applications.densenet import DenseNet201
from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.python.keras.applications.inception_v3 import InceptionV3
from tensorflow.python.keras.applications.mobilenet import MobileNet
+from tensorflow.python.keras.applications.nasnet import NASNetLarge
+from tensorflow.python.keras.applications.nasnet import NASNetMobile
from tensorflow.python.keras.applications.resnet50 import ResNet50
from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.applications.vgg19 import VGG19
diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/python/keras/applications/densenet/__init__.py
new file mode 100644
index 0000000000..6b8ea83920
--- /dev/null
+++ b/tensorflow/python/keras/applications/densenet/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""DenseNet Keras applications."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201
+from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/nasnet/__init__.py b/tensorflow/python/keras/applications/nasnet/__init__.py
new file mode 100644
index 0000000000..94eb145b85
--- /dev/null
+++ b/tensorflow/python/keras/applications/nasnet/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""NASNet Keras applications."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.nasnet import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge
+from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile
+from tensorflow.python.keras._impl.keras.applications.nasnet import preprocess_input
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index b94bf8f0f6..84ee5040dc 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -30,6 +30,7 @@ from tensorflow.python.keras._impl.keras.layers.advanced_activations import Leak
from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU
from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU
from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import Softmax
# Convolution layers.
from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D
@@ -37,6 +38,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D
from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D
from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose
from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv1D
from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D
# Convolution layer aliases.
@@ -45,6 +47,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution
from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D
from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose
from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution1D
from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D
# Image processing layers.
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index de6aba4477..8c1d16c2a8 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -87,6 +87,8 @@ cuda_py_test(
srcs = ["list_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:list_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -2488,6 +2490,7 @@ cuda_py_test(
"//tensorflow/python:sparse_ops",
],
shard_count = 5,
+ tags = ["noasan"],
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 1dbe7deb97..a96b88d96f 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -82,7 +82,9 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
matrix_ph = array_ops.placeholder(dtypes.int32)
transposed = array_ops.matrix_transpose(matrix_ph)
self.assertAllEqual(
- expected_transposed, transposed.eval(feed_dict={matrix_ph: matrix}))
+ expected_transposed, transposed.eval(feed_dict={
+ matrix_ph: matrix
+ }))
def testBatchMatrixDynamicallyDefined(self):
matrix_0 = [[1, 2, 3], [4, 5, 6]]
@@ -96,7 +98,9 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
transposed = array_ops.matrix_transpose(batch_matrix_ph)
self.assertAllEqual(
expected_transposed,
- transposed.eval(feed_dict={batch_matrix_ph: batch_matrix}))
+ transposed.eval(feed_dict={
+ batch_matrix_ph: batch_matrix
+ }))
def testTensorWithStaticRankLessThanTwoRaisesBecauseNotAMatrix(self):
vector = [1, 2, 3]
@@ -203,8 +207,10 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
masked_tensor = sess.run(
array_ops.boolean_mask(ph_tensor, ph_mask),
- feed_dict={ph_tensor: arr,
- ph_mask: mask})
+ feed_dict={
+ ph_tensor: arr,
+ ph_mask: mask
+ })
np.testing.assert_allclose(masked_tensor, arr[mask])
def testMaskDimensionsSetToNoneRaises(self):
@@ -280,7 +286,8 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
for axis_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(use_gpu=use_gpu):
x_tf = array_ops.reverse_v2(x_np,
- constant_op.constant([0], dtype=axis_dtype)).eval()
+ constant_op.constant(
+ [0], dtype=axis_dtype)).eval()
self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
def _reverse2DimAuto(self, np_dtype):
@@ -290,16 +297,17 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
for use_gpu in [False, True]:
for axis_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(use_gpu=use_gpu):
- x_tf_1 = reverse_f(x_np,
- constant_op.constant([0], dtype=axis_dtype)).eval()
- x_tf_2 = reverse_f(x_np,
- constant_op.constant([-2], dtype=axis_dtype)).eval()
- x_tf_3 = reverse_f(x_np,
- constant_op.constant([1], dtype=axis_dtype)).eval()
- x_tf_4 = reverse_f(x_np,
- constant_op.constant([-1], dtype=axis_dtype)).eval()
+ x_tf_1 = reverse_f(x_np, constant_op.constant(
+ [0], dtype=axis_dtype)).eval()
+ x_tf_2 = reverse_f(x_np, constant_op.constant(
+ [-2], dtype=axis_dtype)).eval()
+ x_tf_3 = reverse_f(x_np, constant_op.constant(
+ [1], dtype=axis_dtype)).eval()
+ x_tf_4 = reverse_f(x_np, constant_op.constant(
+ [-1], dtype=axis_dtype)).eval()
x_tf_5 = reverse_f(x_np,
- constant_op.constant([1, 0], dtype=axis_dtype)).eval()
+ constant_op.constant([1, 0],
+ dtype=axis_dtype)).eval()
self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
@@ -324,18 +332,16 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
def testReverse1DimAuto(self):
for dtype in [
- np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64,
- np.bool, np.float16, np.float32,
- np.float64, np.complex64, np.complex128,
+ np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, np.bool,
+ np.float16, np.float32, np.float64, np.complex64, np.complex128,
np.array(b"").dtype.type
]:
self._reverse1DimAuto(dtype)
def testReverse2DimAuto(self):
for dtype in [
- np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64,
- np.bool, np.float16, np.float32,
- np.float64, np.complex64, np.complex128,
+ np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, np.bool,
+ np.float16, np.float32, np.float64, np.complex64, np.complex128,
np.array(b"").dtype.type
]:
self._reverse2DimAuto(dtype)
@@ -711,8 +717,8 @@ class GradSliceChecker(object):
slice_val_grad2, = gradients_impl.gradients(
slice_val_grad, dy, grad_ys=self.var)
self.sess.run(assign)
- slice_val_grad_evaled, slice_val_grad2_evaled = (self.sess.run(
- [slice_val_grad, slice_val_grad2]))
+ slice_val_grad_evaled, slice_val_grad2_evaled = (
+ self.sess.run([slice_val_grad, slice_val_grad2]))
analytic_grad2_evaled = analytic_grad2.eval()
self.test.assertAllEqual(slice_val_grad2_evaled, analytic_grad2_evaled)
@@ -975,6 +981,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase):
self.assertEqual(2, array_ops.rank(sp).eval())
+@test_util.with_c_api
class SequenceMaskTest(test_util.TensorFlowTestCase):
def testExceptions(self):
@@ -986,36 +993,41 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
with self.test_session():
res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5)
self.assertAllEqual(res.get_shape(), [3, 5])
- self.assertAllEqual(res.eval(), [[True, False, False, False, False],
- [True, True, True, False, False],
- [True, True, False, False, False]])
+ self.assertAllEqual(
+ res.eval(),
+ [[True, False, False, False, False], [True, True, True, False, False],
+ [True, True, False, False, False]])
# test dtype and default maxlen:
res = array_ops.sequence_mask(
constant_op.constant([0, 1, 4]), dtype=dtypes.float32)
- self.assertAllEqual(res.get_shape().as_list(), [3, None])
- self.assertAllEqual(res.eval(), [[0.0, 0.0, 0.0,
- 0.0], [1.0, 0.0, 0.0, 0.0],
- [1.0, 1.0, 1.0, 1.0]])
+ if ops._USE_C_API:
+ self.assertAllEqual(res.get_shape().as_list(), [3, 4])
+ else:
+ self.assertAllEqual(res.get_shape().as_list(), [3, None])
+ self.assertAllEqual(
+ res.eval(),
+ [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
def testTwoDimensional(self):
with self.test_session():
res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5)
self.assertAllEqual(res.get_shape(), [1, 3, 5])
- self.assertAllEqual(res.eval(), [[[True, False, False, False, False],
- [True, True, True, False, False],
- [True, True, False, False, False]]])
+ self.assertAllEqual(res.eval(), [[[True, False, False, False, False], [
+ True, True, True, False, False
+ ], [True, True, False, False, False]]])
# test dtype and default maxlen:
res = array_ops.sequence_mask(
constant_op.constant([[0, 1, 4], [1, 2, 3]]), dtype=dtypes.float32)
- self.assertAllEqual(res.get_shape().as_list(), [2, 3, None])
- self.assertAllEqual(res.eval(), [[[0.0, 0.0, 0.0, 0.0],
- [1.0, 0.0, 0.0, 0.0],
- [1.0, 1.0, 1.0, 1.0]],
- [[1.0, 0.0, 0.0, 0.0],
- [1.0, 1.0, 0.0, 0.0],
- [1.0, 1.0, 1.0, 0.0]]])
+ if ops._USE_C_API:
+ self.assertAllEqual(res.get_shape().as_list(), [2, 3, 4])
+ else:
+ self.assertAllEqual(res.get_shape().as_list(), [2, 3, None])
+ self.assertAllEqual(
+ res.eval(),
+ [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
+ [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]]])
def testDtypes(self):
@@ -1024,9 +1036,10 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
constant_op.constant([1, 3, 2], dtype=lengths_dtype),
constant_op.constant(5, dtype=maxlen_dtype))
self.assertAllEqual(res.get_shape(), [3, 5])
- self.assertAllEqual(res.eval(), [[True, False, False, False, False],
- [True, True, True, False, False],
- [True, True, False, False, False]])
+ self.assertAllEqual(
+ res.eval(),
+ [[True, False, False, False, False], [True, True, True, False, False],
+ [True, True, False, False, False]])
with self.test_session():
check_dtypes(dtypes.int32, dtypes.int32)
@@ -1081,13 +1094,14 @@ class PadTest(test_util.TensorFlowTestCase):
def testEager(self):
with context.eager_mode():
t = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- paddings = constant_op.constant([[1, 1,], [2, 2]])
+ paddings = constant_op.constant([[
+ 1,
+ 1,
+ ], [2, 2]])
padded = array_ops.pad(t, paddings, "CONSTANT")
self.assertAllEqual(padded.numpy(),
- [[0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 2, 3, 0, 0],
- [0, 0, 4, 5, 6, 0, 0],
- [0, 0, 0, 0, 0, 0, 0]])
+ [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0, 0],
+ [0, 0, 4, 5, 6, 0, 0], [0, 0, 0, 0, 0, 0, 0]])
class InvertPermutationTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 030c690167..576bb68ba4 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -454,18 +454,20 @@ class ZerosLikeTest(test.TestCase):
def testZerosLikeCPU(self):
for dtype in [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int8,
- dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16, dtypes_lib.int32,
- dtypes_lib.int64, dtypes_lib.bool, dtypes_lib.complex64,
- dtypes_lib.complex128, dtypes_lib.string
+ dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64,
+ dtypes_lib.int8, dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16,
+ dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.bool,
+ dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.string
]:
self._compareZeros(dtype, fully_defined_shape=False, use_gpu=False)
self._compareZeros(dtype, fully_defined_shape=True, use_gpu=False)
def testZerosLikeGPU(self):
for dtype in [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
- dtypes_lib.bool, dtypes_lib.int64, dtypes_lib.string
+ dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64,
+ dtypes_lib.int32, dtypes_lib.int64,
+ dtypes_lib.complex64, dtypes_lib.complex128,
+ dtypes_lib.bool
]:
self._compareZeros(dtype, fully_defined_shape=False, use_gpu=True)
self._compareZeros(dtype, fully_defined_shape=True, use_gpu=True)
diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py
index d92797a7d3..e2e6205911 100644
--- a/tensorflow/python/kernel_tests/conv1d_test.py
+++ b/tensorflow/python/kernel_tests/conv1d_test.py
@@ -30,27 +30,29 @@ from tensorflow.python.platform import test
class Conv1DTest(test.TestCase):
def testBasic(self):
- """Test that argument passing to conv2d is handled properly."""
-
- x = constant_op.constant([1, 2, 3, 4], dtype=dtypes.float32)
- x = array_ops.expand_dims(x, 0) # Add batch dimension
- x = array_ops.expand_dims(x, 2) # And depth dimension
- filters = constant_op.constant([2, 1], dtype=dtypes.float32)
- filters = array_ops.expand_dims(filters, 1) # in_channels
- filters = array_ops.expand_dims(filters, 2) # out_channels
- # Filters is 2x1x1
- for stride in [1, 2]:
- with self.test_session(use_gpu=test.is_gpu_available()):
- c = nn_ops.conv1d(x, filters, stride, padding="VALID")
- reduced = array_ops.squeeze(c)
- output = reduced.eval()
- if stride == 1:
- self.assertEqual(len(output), 3)
- self.assertAllClose(output,
- [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4])
- else:
- self.assertEqual(len(output), 2)
- self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])
+ """Test that argument passing to conv1d is handled properly."""
+ # TODO(yongtang): dtypes.float64 can only be enabled once conv2d support
+ # dtypes.float64, as conv1d implicitly calls conv2d after expand_dims.
+ for dtype in [dtypes.float16, dtypes.float32]:
+ x = constant_op.constant([1, 2, 3, 4], dtype=dtype)
+ x = array_ops.expand_dims(x, 0) # Add batch dimension
+ x = array_ops.expand_dims(x, 2) # And depth dimension
+ filters = constant_op.constant([2, 1], dtype=dtype)
+ filters = array_ops.expand_dims(filters, 1) # in_channels
+ filters = array_ops.expand_dims(filters, 2) # out_channels
+ # Filters is 2x1x1
+ for stride in [1, 2]:
+ with self.test_session(use_gpu=test.is_gpu_available()):
+ c = nn_ops.conv1d(x, filters, stride, padding="VALID")
+ reduced = array_ops.squeeze(c)
+ output = reduced.eval()
+ if stride == 1:
+ self.assertEqual(len(output), 3)
+ self.assertAllClose(output,
+ [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4])
+ else:
+ self.assertEqual(len(output), 2)
+ self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])
def testConv1DTranspose(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index cea12ea8ec..a91917b27f 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@@ -1168,6 +1169,32 @@ class BinaryOpTest(test.TestCase):
self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
+ def testPowNegativeExponent(self):
+ for dtype in [np.int32, np.int64]:
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = np.array([-2, 3]).astype(dtype)
+ sess.run(math_ops.pow(x, y))
+
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = np.array([2, -3]).astype(dtype)
+ sess.run(math_ops.pow(x, y))
+
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = -3
+ sess.run(math_ops.pow(x, y))
+
class ComparisonOpTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index 6cfa9b37fe..0825d8fc6b 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -84,11 +84,8 @@ class MatrixSetDiagTest(test.TestCase):
def testSquare(self):
with self.test_session(use_gpu=True):
v = np.array([1.0, 2.0, 3.0])
- mat = np.array([[0.0, 1.0, 0.0],
- [1.0, 0.0, 1.0],
- [1.0, 1.0, 1.0]])
- mat_set_diag = np.array([[1.0, 1.0, 0.0],
- [1.0, 2.0, 1.0],
+ mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
+ mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0],
[1.0, 1.0, 3.0]])
output = array_ops.matrix_set_diag(mat, v)
self.assertEqual((3, 3), output.get_shape())
@@ -135,19 +132,12 @@ class MatrixSetDiagTest(test.TestCase):
def testRectangularBatch(self):
with self.test_session(use_gpu=True):
- v_batch = np.array([[-1.0, -2.0],
- [-4.0, -5.0]])
- mat_batch = np.array(
- [[[1.0, 0.0, 3.0],
- [0.0, 2.0, 0.0]],
- [[4.0, 0.0, 4.0],
- [0.0, 5.0, 0.0]]])
-
- mat_set_diag_batch = np.array(
- [[[-1.0, 0.0, 3.0],
- [0.0, -2.0, 0.0]],
- [[-4.0, 0.0, 4.0],
- [0.0, -5.0, 0.0]]])
+ v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]])
+ mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
+ [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]])
+
+ mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
+ [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]])
output = array_ops.matrix_set_diag(mat_batch, v_batch)
self.assertEqual((2, 2, 3), output.get_shape())
self.assertAllEqual(mat_set_diag_batch, output.eval())
@@ -178,10 +168,14 @@ class MatrixSetDiagTest(test.TestCase):
np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
y = array_ops.matrix_set_diag(x, x_diag)
error_x = gradient_checker.compute_gradient_error(
- x, x.get_shape().as_list(), y, y.get_shape().as_list())
+ x,
+ x.get_shape().as_list(), y,
+ y.get_shape().as_list())
self.assertLess(error_x, 1e-4)
error_x_diag = gradient_checker.compute_gradient_error(
- x_diag, x_diag.get_shape().as_list(), y, y.get_shape().as_list())
+ x_diag,
+ x_diag.get_shape().as_list(), y,
+ y.get_shape().as_list())
self.assertLess(error_x_diag, 1e-4)
def testGradWithNoShapeInformation(self):
@@ -192,12 +186,13 @@ class MatrixSetDiagTest(test.TestCase):
output = array_ops.matrix_set_diag(mat, v)
grads = gradients_impl.gradients(output, [mat, v], grad_ys=grad_input)
grad_input_val = np.random.rand(3, 3).astype(np.float32)
- grad_vals = sess.run(grads,
- feed_dict={
- v: 2 * np.ones(3),
- mat: np.ones((3, 3)),
- grad_input: grad_input_val
- })
+ grad_vals = sess.run(
+ grads,
+ feed_dict={
+ v: 2 * np.ones(3),
+ mat: np.ones((3, 3)),
+ grad_input: grad_input_val
+ })
self.assertAllEqual(np.diag(grad_input_val), grad_vals[1])
self.assertAllEqual(grad_input_val - np.diag(np.diag(grad_input_val)),
grad_vals[0])
@@ -242,13 +237,9 @@ class MatrixDiagPartTest(test.TestCase):
def testRectangularBatch(self):
with self.test_session(use_gpu=True):
- v_batch = np.array([[1.0, 2.0],
- [4.0, 5.0]])
- mat_batch = np.array(
- [[[1.0, 0.0, 0.0],
- [0.0, 2.0, 0.0]],
- [[4.0, 0.0, 0.0],
- [0.0, 5.0, 0.0]]])
+ v_batch = np.array([[1.0, 2.0], [4.0, 5.0]])
+ mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
+ [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]])
self.assertEqual(mat_batch.shape, (2, 2, 3))
mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
self.assertEqual((2, 2), mat_batch_diag.get_shape())
@@ -301,19 +292,13 @@ class DiagTest(test.TestCase):
def testRankOneIntTensor(self):
x = np.array([1, 2, 3])
- expected_ans = np.array(
- [[1, 0, 0],
- [0, 2, 0],
- [0, 0, 3]])
+ expected_ans = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
self.diagOp(x, np.int32, expected_ans)
self.diagOp(x, np.int64, expected_ans)
def testRankOneFloatTensor(self):
x = np.array([1.1, 2.2, 3.3])
- expected_ans = np.array(
- [[1.1, 0, 0],
- [0, 2.2, 0],
- [0, 0, 3.3]])
+ expected_ans = np.array([[1.1, 0, 0], [0, 2.2, 0], [0, 0, 3.3]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
@@ -321,123 +306,105 @@ class DiagTest(test.TestCase):
for dtype in [np.complex64, np.complex128]:
x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype)
expected_ans = np.array(
- [[1.1 + 1.1j, 0 + 0j, 0 + 0j],
- [0 + 0j, 2.2 + 2.2j, 0 + 0j],
- [0 + 0j, 0 + 0j, 3.3 + 3.3j]], dtype=dtype)
+ [[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 2.2 + 2.2j, 0 + 0j],
+ [0 + 0j, 0 + 0j, 3.3 + 3.3j]],
+ dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankTwoIntTensor(self):
x = np.array([[1, 2, 3], [4, 5, 6]])
- expected_ans = np.array(
- [[[[1, 0, 0], [0, 0, 0]],
- [[0, 2, 0], [0, 0, 0]],
- [[0, 0, 3], [0, 0, 0]]],
- [[[0, 0, 0], [4, 0, 0]],
- [[0, 0, 0], [0, 5, 0]],
- [[0, 0, 0], [0, 0, 6]]]])
+ expected_ans = np.array([[[[1, 0, 0], [0, 0, 0]], [[0, 2, 0], [0, 0, 0]],
+ [[0, 0, 3], [0, 0, 0]]],
+ [[[0, 0, 0], [4, 0, 0]], [[0, 0, 0], [0, 5, 0]],
+ [[0, 0, 0], [0, 0, 6]]]])
self.diagOp(x, np.int32, expected_ans)
self.diagOp(x, np.int64, expected_ans)
def testRankTwoFloatTensor(self):
x = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
expected_ans = np.array(
- [[[[1.1, 0, 0], [0, 0, 0]],
- [[0, 2.2, 0], [0, 0, 0]],
- [[0, 0, 3.3], [0, 0, 0]]],
- [[[0, 0, 0], [4.4, 0, 0]],
- [[0, 0, 0], [0, 5.5, 0]],
- [[0, 0, 0], [0, 0, 6.6]]]])
+ [[[[1.1, 0, 0], [0, 0, 0]], [[0, 2.2, 0], [0, 0, 0]],
+ [[0, 0, 3.3], [0, 0, 0]]], [[[0, 0, 0], [4.4, 0, 0]],
+ [[0, 0, 0], [0, 5.5, 0]], [[0, 0, 0],
+ [0, 0, 6.6]]]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
def testRankTwoComplexTensor(self):
for dtype in [np.complex64, np.complex128]:
- x = np.array([[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
- [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], dtype=dtype)
+ x = np.array(
+ [[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
+ [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]],
+ dtype=dtype)
expected_ans = np.array(
- [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
- [[0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]],
- [[[0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]],
- [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
- dtype=dtype)
+ [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], [
+ [0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]
+ ], [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], [[
+ [0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]
+ ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]
+ ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
+ dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankThreeFloatTensor(self):
- x = np.array([[[1.1, 2.2], [3.3, 4.4]],
- [[5.5, 6.6], [7.7, 8.8]]])
- expected_ans = np.array(
- [[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
- [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
- [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
- [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
- [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
- [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
- [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
- [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
+ x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]])
+ expected_ans = np.array([[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
+ [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
+ [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
+ [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
+ [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
+ [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
+ [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
+ [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
def testRankThreeComplexTensor(self):
for dtype in [np.complex64, np.complex128]:
- x = np.array([[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
- [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
- dtype=dtype)
+ x = np.array(
+ [[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
+ [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
+ dtype=dtype)
expected_ans = np.array(
- [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
- [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]],
- [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
- [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]]],
- [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]]],
- [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 6.6 + 6.6j], [0 + 0j, 0 + 0j]]]],
- [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [7.7 + 7.7j, 0 + 0j]]],
- [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
+ [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
+ 0 + 0j, 0 + 0j
+ ]]], [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
+ 0 + 0j, 0 + 0j
+ ]]]], [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
+ 0 + 0j, 0 + 0j
+ ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], [[0 + 0j, 0 + 0j], [
+ 0 + 0j, 0 + 0j
+ ]]]]], [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [
+ [5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]
+ ]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 6.6 + 6.6j], [
+ 0 + 0j, 0 + 0j
+ ]]]], [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
+ 7.7 + 7.7j, 0 + 0j
+ ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankFourNumberTensor(self):
for dtype in [np.float32, np.float64, np.int64, np.int32]:
# Input with shape [2, 1, 2, 3]
- x = np.array([[[[ 1, 2, 3],
- [ 4, 5, 6]]],
- [[[ 7, 8, 9],
- [10, 11, 12]]]], dtype=dtype)
+ x = np.array(
+ [[[[1, 2, 3], [4, 5, 6]]], [[[7, 8, 9], [10, 11, 12]]]], dtype=dtype)
# Output with shape [2, 1, 2, 3, 2, 1, 2, 3]
expected_ans = np.array(
- [[[[[[[[1, 0, 0], [0, 0, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]],
- [[[[0, 2, 0], [0, 0, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]],
- [[[[0, 0, 3], [0, 0, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]]],
- [[[[[0, 0, 0], [4, 0, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]],
- [[[[0, 0, 0], [0, 5, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]],
- [[[[0, 0, 0], [0, 0, 6]]],
- [[[0, 0, 0], [0, 0, 0]]]]]]],
-
- [[[[[[[0, 0, 0], [0, 0, 0]]],
- [[[7, 0, 0], [0, 0, 0]]]],
- [[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 8, 0], [0, 0, 0]]]],
- [[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 0, 9], [0, 0, 0]]]]],
- [[[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 0, 0], [10, 0, 0]]]],
- [[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 0, 0], [0, 11, 0]]]],
- [[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 0, 0], [0, 0, 12]]]]]]]], dtype=dtype)
+ [[[[[[[[1, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
+ [[[0, 2, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
+ ], [[[[0, 0, 3], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]]], [[
+ [[[0, 0, 0], [4, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
+ ], [[[[0, 0, 0], [0, 5, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
+ [[[0, 0, 0], [0, 0, 6]]], [[[0, 0, 0], [0, 0, 0]]]
+ ]]]], [[[[[[[0, 0, 0], [0, 0, 0]]], [[[7, 0, 0], [0, 0, 0]]]], [
+ [[[0, 0, 0], [0, 0, 0]]], [[[0, 8, 0], [0, 0, 0]]]
+ ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 9], [0, 0, 0]]]]], [[
+ [[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [10, 0, 0]]]
+ ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 11, 0]]]
+ ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 12]]]]]]]],
+ dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testInvalidRank(self):
@@ -537,7 +504,9 @@ class DiagGradOpTest(test.TestCase):
x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
y = array_ops.diag(x1)
error = gradient_checker.compute_gradient_error(
- x1, x1.get_shape().as_list(), y, y.get_shape().as_list())
+ x1,
+ x1.get_shape().as_list(), y,
+ y.get_shape().as_list())
tf_logging.info("error = %f", error)
self.assertLess(error, 1e-4)
@@ -555,7 +524,9 @@ class DiagGradPartOpTest(test.TestCase):
x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
y = array_ops.diag_part(x1)
error = gradient_checker.compute_gradient_error(
- x1, x1.get_shape().as_list(), y, y.get_shape().as_list())
+ x1,
+ x1.get_shape().as_list(), y,
+ y.get_shape().as_list())
tf_logging.info("error = %f", error)
self.assertLess(error, 1e-4)
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index 019c1bc353..ca2358fe99 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -100,6 +100,10 @@ class CategoricalTest(test.TestCase):
self.assertEqual(
dist.logits.dtype, dist.log_prob(np.array(
0, dtype=np.int64)).dtype)
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ dist = make_categorical([], 5, dtype=dtype)
+ self.assertEqual(dist.dtype, dtype)
+ self.assertEqual(dist.dtype, dist.sample(5).dtype)
def testUnknownShape(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 8fae044e2e..1577b7bc80 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -82,6 +83,21 @@ class ListOpsTest(test_util.TensorFlowTestCase):
with context.device("gpu:0"):
self.testTensorListFromTensor()
+ def testGetSetItem(self):
+ t = constant_op.constant([1.0, 2.0])
+ l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
+ e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
+ self.assertAllEqual(e0, 1.0)
+ l = list_ops.tensor_list_set_item(l, 0, 3.0)
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
+ self.assertAllEqual(t, [3.0, 2.0])
+
+ def testGetSetGPU(self):
+ if not context.num_gpus():
+ return
+ with context.device("gpu:0"):
+ self.testGetSetItem()
+
def testUnknownShape(self):
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
element_shape=-1)
@@ -159,6 +175,27 @@ class ListOpsTest(test_util.TensorFlowTestCase):
result = c2 * 2.0
self.assertAllEqual(tape.gradient(result, [c])[0], [2.0, 2.0])
+ def testGetSetGradients(self):
+ with backprop.GradientTape() as tape:
+ c = constant_op.constant([1.0, 2.0])
+ tape.watch(c)
+ l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
+ c2 = constant_op.constant(3.0)
+ tape.watch(c2)
+ l = list_ops.tensor_list_set_item(l, 0, c2)
+ e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
+ ee = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
+ y = e * e + ee * ee
+ grad_c, grad_c2 = tape.gradient(y, [c, c2])
+ self.assertAllEqual(grad_c, [0.0, 4.0])
+ self.assertAllEqual(grad_c2, 6.0)
+
+ def testSetOutOfBounds(self):
+ c = constant_op.constant([1.0, 2.0])
+ l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
+ with self.assertRaises(errors.InvalidArgumentError):
+ list_ops.tensor_list_set_item(l, 20, 3.0)
+
if __name__ == "__main__":
ops.enable_eager_execution()
diff --git a/tensorflow/python/kernel_tests/map_stage_op_test.py b/tensorflow/python/kernel_tests/map_stage_op_test.py
index 8b66945059..acfafde9e0 100644
--- a/tensorflow/python/kernel_tests/map_stage_op_test.py
+++ b/tensorflow/python/kernel_tests/map_stage_op_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.platform import test
TIMEOUT = 1
+
class MapStageTest(test.TestCase):
def testSimple(self):
@@ -83,7 +84,7 @@ class MapStageTest(test.TestCase):
[dtypes.float32, dtypes.float32],
shapes=[[], [128, 128]],
names=['x', 'v'])
- stage = stager.put(pi,{'x': x, 'v': v})
+ stage = stager.put(pi, {'x': x, 'v': v})
key, ret = stager.get(gi)
z = ret['x']
y = ret['v']
@@ -128,8 +129,11 @@ class MapStageTest(test.TestCase):
gi = array_ops.placeholder(dtypes.int64)
p = array_ops.placeholder(dtypes.int32, name='p')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.MapStagingArea([dtypes.int32, ], shapes=[[]])
- stage = stager.put(pi,[x], [0])
+ stager = data_flow_ops.MapStagingArea(
+ [
+ dtypes.int32,
+ ], shapes=[[]])
+ stage = stager.put(pi, [x], [0])
peek = stager.peek(gi)
size = stager.size()
@@ -158,7 +162,7 @@ class MapStageTest(test.TestCase):
[dtypes.float32, dtypes.float32],
shapes=[[], [128, 128]],
names=['x', 'v'])
- stage = stager.put(pi,{'x': x, 'v': v})
+ stage = stager.put(pi, {'x': x, 'v': v})
size = stager.size()
clear = stager.clear()
@@ -172,7 +176,6 @@ class MapStageTest(test.TestCase):
sess.run(clear)
self.assertEqual(sess.run(size), 0)
-
def testCapacity(self):
capacity = 3
@@ -182,8 +185,10 @@ class MapStageTest(test.TestCase):
pi = array_ops.placeholder(dtypes.int64, name='pi')
gi = array_ops.placeholder(dtypes.int64, name='gi')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.MapStagingArea([dtypes.int32, ],
- capacity=capacity, shapes=[[]])
+ stager = data_flow_ops.MapStagingArea(
+ [
+ dtypes.int32,
+ ], capacity=capacity, shapes=[[]])
stage = stager.put(pi, [x], [0])
get = stager.get()
@@ -222,9 +227,8 @@ class MapStageTest(test.TestCase):
self.fail("Expected to timeout on iteration '{}' "
"but instead timed out on iteration '{}' "
"Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
+ "capacity is '{}'.".format(capacity, i, sess.run(size),
+ capacity))
# Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
@@ -236,8 +240,8 @@ class MapStageTest(test.TestCase):
self.assertTrue(sess.run(size) == 0)
def testMemoryLimit(self):
- memory_limit = 512*1024 # 512K
- chunk = 200*1024 # 256K
+ memory_limit = 512 * 1024 # 512K
+ chunk = 200 * 1024 # 256K
capacity = memory_limit // chunk
with ops.Graph().as_default() as G:
@@ -246,8 +250,8 @@ class MapStageTest(test.TestCase):
pi = array_ops.placeholder(dtypes.int64, name='pi')
gi = array_ops.placeholder(dtypes.int64, name='gi')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.MapStagingArea([dtypes.uint8],
- memory_limit=memory_limit, shapes=[[]])
+ stager = data_flow_ops.MapStagingArea(
+ [dtypes.uint8], memory_limit=memory_limit, shapes=[[]])
stage = stager.put(pi, [x], [0])
get = stager.get()
size = stager.size()
@@ -287,9 +291,8 @@ class MapStageTest(test.TestCase):
self.fail("Expected to timeout on iteration '{}' "
"but instead timed out on iteration '{}' "
"Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
+ "capacity is '{}'.".format(capacity, i, sess.run(size),
+ capacity))
# Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
@@ -310,8 +313,10 @@ class MapStageTest(test.TestCase):
pi = array_ops.placeholder(dtypes.int64, name='pi')
gi = array_ops.placeholder(dtypes.int64, name='gi')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.MapStagingArea([dtypes.int32, ],
- shapes=[[]], ordered=True)
+ stager = data_flow_ops.MapStagingArea(
+ [
+ dtypes.int32,
+ ], shapes=[[]], ordered=True)
stage = stager.put(pi, [x], [0])
get = stager.get()
size = stager.size()
@@ -349,7 +354,7 @@ class MapStageTest(test.TestCase):
stager = data_flow_ops.MapStagingArea(
[dtypes.float32, dtypes.float32, dtypes.float32],
names=['x', 'v', 'f'])
- stage_xf = stager.put(pi,{'x': x, 'f': f})
+ stage_xf = stager.put(pi, {'x': x, 'f': f})
stage_v = stager.put(pi, {'v': v})
key, ret = stager.get(gi)
size = stager.size()
@@ -373,12 +378,13 @@ class MapStageTest(test.TestCase):
self.assertTrue(sess.run([size, isize]) == [1, 1])
# We can now obtain tuple associated with key 0
self.assertTrue(
- sess.run([key, ret],
- feed_dict={gi: 0}) == [0, {
- 'x': 1,
- 'f': 2,
- 'v': 1
- }])
+ sess.run([key, ret], feed_dict={
+ gi: 0
+ }) == [0, {
+ 'x': 1,
+ 'f': 2,
+ 'v': 1
+ }])
# 0 complete and 1 incomplete entry
self.assertTrue(sess.run([size, isize]) == [0, 1])
@@ -386,12 +392,13 @@ class MapStageTest(test.TestCase):
sess.run(stage_v, feed_dict={pi: 1, v: 3})
# We can now obtain tuple associated with key 1
self.assertTrue(
- sess.run([key, ret],
- feed_dict={gi: 1}) == [1, {
- 'x': 1,
- 'f': 2,
- 'v': 3
- }])
+ sess.run([key, ret], feed_dict={
+ gi: 1
+ }) == [1, {
+ 'x': 1,
+ 'f': 2,
+ 'v': 3
+ }])
def testPartialIndexInsert(self):
with ops.Graph().as_default() as G:
@@ -450,7 +457,7 @@ class MapStageTest(test.TestCase):
stager = data_flow_ops.MapStagingArea(
[dtypes.float32, dtypes.float32, dtypes.float32],
names=['x', 'v', 'f'])
- stage_xf = stager.put(pi,{'x': x, 'f': f})
+ stage_xf = stager.put(pi, {'x': x, 'f': f})
stage_v = stager.put(pi, {'v': v})
peek_xf = stager.peek(pei, ['x', 'f'])
peek_v = stager.peek(pei, ['v'])
@@ -487,11 +494,12 @@ class MapStageTest(test.TestCase):
# We can now obtain 'x' and 'f' values associated with key 0
self.assertTrue(
- sess.run([key_xf, get_xf],
- feed_dict={gi: 0}) == [0, {
- 'x': 1,
- 'f': 2
- }])
+ sess.run([key_xf, get_xf], feed_dict={
+ gi: 0
+ }) == [0, {
+ 'x': 1,
+ 'f': 2
+ }])
# Still have 1 complete and 1 incomplete entry
self.assertTrue(sess.run([size, isize]) == [1, 1])
@@ -499,14 +507,15 @@ class MapStageTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError) as cm:
sess.run([key_xf, get_xf], feed_dict={gi: 0})
- exc_str = ("Tensor at index '0' for key '0' "
- "has already been removed.")
+ exc_str = ("Tensor at index '0' for key '0' " 'has already been removed.')
self.assertTrue(exc_str in cm.exception.message)
# Obtain 'v' value associated with key 0
self.assertTrue(
- sess.run([key_v, get_v], feed_dict={gi: 0}) == [0, {
+ sess.run([key_v, get_v], feed_dict={
+ gi: 0
+ }) == [0, {
'v': 1
}])
# 0 complete and 1 incomplete entry
@@ -523,7 +532,9 @@ class MapStageTest(test.TestCase):
self.assertTrue(sess.run([size, isize]) == [1, 0])
# We can now obtain 'x' and 'f' values associated with key 1
self.assertTrue(
- sess.run([pop_key_v, pop_v], feed_dict={pi: 1}) == [1, {
+ sess.run([pop_key_v, pop_v], feed_dict={
+ pi: 1
+ }) == [1, {
'v': 1
}])
# Nothing is left
@@ -557,18 +568,20 @@ class MapStageTest(test.TestCase):
self.assertTrue(sess.run([size, isize]) == [1, 0])
# Partial get using indices
- self.assertTrue(sess.run([key_xf, get_xf],
- feed_dict={gi: 0}) == [0, [1, 2]])
+ self.assertTrue(
+ sess.run([key_xf, get_xf], feed_dict={
+ gi: 0
+ }) == [0, [1, 2]])
# Still some of key 0 left
self.assertTrue(sess.run([size, isize]) == [1, 0])
# Partial get of remaining index
- self.assertTrue(sess.run([key_v, get_v],
- feed_dict={gi: 0}) == [0, [3]])
+ self.assertTrue(sess.run([key_v, get_v], feed_dict={gi: 0}) == [0, [3]])
# All gone
self.assertTrue(sess.run([size, isize]) == [0, 0])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index 3358b78efd..e0e752147c 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -3628,7 +3628,8 @@ class MeanPerClassAccuracyTest(test.TestCase):
predictions=array_ops.ones([10, 1]),
labels=array_ops.ones([10, 1]),
num_classes=2)
- _assert_metric_variables(self, ('mean_accuracy/total_confusion_matrix:0',))
+ _assert_metric_variables(self, ('mean_accuracy/count:0',
+ 'mean_accuracy/total:0'))
def testMetricsCollections(self):
my_collection_name = '__metrics__'
@@ -3797,23 +3798,6 @@ class MeanPerClassAccuracyTest(test.TestCase):
desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.])
self.assertAlmostEqual(desired_output, mean_accuracy.eval())
- def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
- predictions = array_ops.concat([
- constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5])
- ], 0)
- labels = array_ops.concat([
- constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7])
- ], 0)
- num_classes = 2
- with self.test_session() as sess:
- mean_accuracy, update_op = metrics.mean_per_class_accuracy(
- labels, predictions, num_classes)
- sess.run(variables.local_variables_initializer())
- confusion_matrix = update_op.eval()
- self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix)
- desired_mean_accuracy = np.mean([3. / 3., 5. / 7.])
- self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval())
-
def testAllCorrect(self):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
@@ -3822,7 +3806,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
- self.assertEqual(40, update_op.eval()[0])
+ self.assertEqual(1.0, update_op.eval()[0])
self.assertEqual(1.0, mean_accuracy.eval())
def testAllWrong(self):
@@ -3833,7 +3817,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
- self.assertAllEqual([[0, 0], [40, 0]], update_op.eval())
+ self.assertAllEqual([0.0, 0.0], update_op.eval())
self.assertEqual(0., mean_accuracy.eval())
def testResultsWithSomeMissing(self):
@@ -3852,8 +3836,9 @@ class MeanPerClassAccuracyTest(test.TestCase):
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
- self.assertAllEqual([[2, 0], [2, 4]], update_op.eval())
- desired_mean_accuracy = np.mean([2. / 2., 4. / 6.])
+ desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32)
+ self.assertAllEqual(desired_accuracy, update_op.eval())
+ desired_mean_accuracy = np.mean(desired_accuracy)
self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval())
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 5c0ea8ec8e..3263ed1a60 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -159,8 +159,10 @@ class PoolingTest(test.TestCase):
elif data_format == "NCHW":
t = test_util.NCHWToNHWC(t)
if v2:
- actual = t.eval(feed_dict={ksize_placeholder: ksize,
- strides_placeholder: strides})
+ actual = t.eval(feed_dict={
+ ksize_placeholder: ksize,
+ strides_placeholder: strides
+ })
else:
actual = t.eval()
self.assertShapeEqual(actual, t)
@@ -195,8 +197,15 @@ class PoolingTest(test.TestCase):
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, dtypes.float16, expected, use_gpu, v2)
- def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
- expected, use_gpu, v2=False):
+ def _VerifyValues(self,
+ pool_func,
+ input_sizes,
+ ksize,
+ strides,
+ padding,
+ expected,
+ use_gpu,
+ v2=False):
"""Verifies the output values of the pooling function.
Args:
@@ -1148,16 +1157,16 @@ class PoolingTest(test.TestCase):
def _testMaxPoolGradSamePadding3_1(self, data_format, use_gpu):
for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
self._ConstructAndTestGradient(
- pool_func,
- input_sizes=[1, 7, 7, 1],
- output_sizes=[1, 7, 7, 1],
- window_rows=3,
- window_cols=3,
- row_stride=1,
- col_stride=1,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ pool_func,
+ input_sizes=[1, 7, 7, 1],
+ output_sizes=[1, 7, 7, 1],
+ window_rows=3,
+ window_cols=3,
+ row_stride=1,
+ col_stride=1,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def testMaxPoolGrad(self):
for (data_format, use_gpu) in GetTestConfigs():
@@ -1202,17 +1211,14 @@ class PoolingTest(test.TestCase):
pool_func = gen_nn_ops._max_pool_v2 if v2 else nn_ops.max_pool
with self.test_session(use_gpu=use_gpu):
input_tensor = constant_op.constant(input_data, shape=input_sizes)
- output_tensor = pool_func(input_tensor,
- [1, window_rows, window_cols, 1],
+ output_tensor = pool_func(input_tensor, [1, window_rows, window_cols, 1],
[1, row_stride, col_stride, 1], padding)
output_backprop_tensor = constant_op.constant(
output_backprop, shape=output_sizes)
- input_backprop_tensor = self._MaxPoolGrad(input_tensor, output_tensor,
- output_backprop_tensor,
- window_rows, window_cols,
- row_stride, col_stride,
- padding, v2)
+ input_backprop_tensor = self._MaxPoolGrad(
+ input_tensor, output_tensor, output_backprop_tensor, window_rows,
+ window_cols, row_stride, col_stride, padding, v2)
actual_input_backprop = input_backprop_tensor.eval()
self.assertShapeEqual(actual_input_backprop, input_backprop_tensor)
@@ -1414,13 +1420,15 @@ class PoolingTest(test.TestCase):
def _testMaxPoolGradDirectWithNans2_2(self):
input_data = [float("nan")] * 16
output_backprop = [
- float("nan"), 12.0, 13.0, 15.0, float("nan"), 17.0, 19.0, 20.0,
+ float("nan"), 12.0, 13.0, 15.0,
+ float("nan"), 17.0, 19.0, 20.0,
float("nan")
]
# Test the CPU implementation, which propagates diffs in case of NaN
expected_input_backprop_tf_cpu = [
- float("nan"), 12.0, 13.0, 0.0, 15.0, float("nan"), 17.0, 0.0, 19.0,
- 20.0, float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0
+ float("nan"), 12.0, 13.0, 0.0, 15.0,
+ float("nan"), 17.0, 0.0, 19.0, 20.0,
+ float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0
]
for v2 in [True, False]:
self._testMaxPoolGradDirect(
@@ -1636,10 +1644,9 @@ class PoolingTest(test.TestCase):
Returns:
A Tensor.
"""
- return gen_nn_ops._max_pool_grad_grad(orig_input, orig_output, grad,
- [1, window_rows, window_cols,
- 1], [1, row_stride, col_stride,
- 1], padding)
+ return gen_nn_ops._max_pool_grad_grad(
+ orig_input, orig_output, grad, [1, window_rows, window_cols, 1],
+ [1, row_stride, col_stride, 1], padding)
def testAvgPoolGrad(self):
for (data_format, use_gpu) in GetTestConfigs():
@@ -1793,8 +1800,7 @@ class PoolingTest(test.TestCase):
]:
with self.assertRaises(ValueError):
pool_func(
- array_ops.placeholder(
- dtypes.float32, shape=[1, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[1, 3]),
ksize=[1, 1, 1, 1],
strides=[1, 1, 1, 1],
padding="SAME")
@@ -1820,15 +1826,13 @@ class PoolingTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
sess.run(
pool_func(
- array_ops.placeholder(
- dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
ksize=[1, 20, 21, 1],
strides=[1, 1, 1, 1],
padding="VALID"))
with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
pool_func(
- array_ops.placeholder(
- dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
ksize=[1, 21, 20, 1],
strides=[1, 1, 1, 1],
padding="VALID")
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 223a4b2c87..82a27eebee 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -428,7 +428,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
for i in range(self._num_files):
fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
filenames.append(fn)
- with open(fn+".tmp", "wb") as f:
+ with open(fn + ".tmp", "wb") as f:
f.write(b"H" * self._header_bytes)
if num_records > 0:
f.write(self._Record(i, 0))
@@ -437,7 +437,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
f.write(b"G" * gap_bytes)
f.write(self._Record(i, j))
f.write(b"F" * self._footer_bytes)
- with open(fn+".tmp", "rb") as f:
+ with open(fn + ".tmp", "rb") as f:
cdata = zlib.compress(f.read())
with open(fn, "wb") as zf:
zf.write(cdata)
@@ -455,7 +455,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
all_records_str = "".join([
str(i)[0]
for i in range(self._record_bytes + self._hop_bytes *
- (num_overlapped_records - 1))
+ (num_overlapped_records - 1))
])
f.write(compat.as_bytes(all_records_str))
f.write(b"F" * self._footer_bytes)
@@ -467,7 +467,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
fn = os.path.join(self.get_temp_dir(),
"fixed_length_overlapped_record.%d.txt" % i)
filenames.append(fn)
- with open(fn+".tmp", "wb") as f:
+ with open(fn + ".tmp", "wb") as f:
f.write(b"H" * self._header_bytes)
if num_overlapped_records > 0:
all_records_str = "".join([
@@ -477,7 +477,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
])
f.write(compat.as_bytes(all_records_str))
f.write(b"F" * self._footer_bytes)
- with open(fn+".tmp", "rb") as f:
+ with open(fn + ".tmp", "rb") as f:
cdata = zlib.compress(f.read())
with open(fn, "wb") as zf:
zf.write(cdata)
@@ -509,7 +509,10 @@ class FixedLengthRecordReaderTest(test.TestCase):
"\\(requested 1, current size 0\\)"):
k, v = sess.run([key, value])
- def _TestOneEpochWithHopBytes(self, files, num_overlapped_records, encoding=None):
+ def _TestOneEpochWithHopBytes(self,
+ files,
+ num_overlapped_records,
+ encoding=None):
with self.test_session() as sess:
reader = io_ops.FixedLengthRecordReader(
header_bytes=self._header_bytes,
@@ -565,13 +568,15 @@ class FixedLengthRecordReaderTest(test.TestCase):
def testGzipOneEpochWithHopBytes(self):
for num_overlapped_records in [0, 2]:
- files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records, )
- self._TestOneEpochWithHopBytes(files, num_overlapped_records, encoding="GZIP")
+ files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records,)
+ self._TestOneEpochWithHopBytes(
+ files, num_overlapped_records, encoding="GZIP")
def testZlibOneEpochWithHopBytes(self):
for num_overlapped_records in [0, 2]:
files = self._CreateZlibOverlappedRecordFiles(num_overlapped_records)
- self._TestOneEpochWithHopBytes(files, num_overlapped_records, encoding="ZLIB")
+ self._TestOneEpochWithHopBytes(
+ files, num_overlapped_records, encoding="ZLIB")
class TFRecordReaderTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index dd11ba700d..6b4091ae5d 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -48,8 +48,8 @@ class ReluTest(test.TestCase):
self.assertAllClose(
np.array([[0.0, 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
self._npRelu(
- np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
- ])))
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]])))
def _testRelu(self, np_features, use_gpu=False):
np_relu = self._npRelu(np_features)
@@ -163,8 +163,8 @@ class Relu6Test(test.TestCase):
self.assertAllClose(
np.array([[0.0, 0.7, 0.0, 0.3, 6.0], [0.1, 0.0, 6.0, 0.0, 0.9]]),
self._npRelu6(
- np.array([[-0.9, 0.7, -0.5, 0.3, 6.0], [0.1, -0.3, 6.5, -0.7, 0.9]
- ])))
+ np.array([[-0.9, 0.7, -0.5, 0.3, 6.0], [0.1, -0.3, 6.5, -0.7,
+ 0.9]])))
def _testRelu6(self, np_features, use_gpu=False):
np_relu6 = self._npRelu6(np_features)
@@ -231,8 +231,8 @@ class EluTest(test.TestCase):
np.array([[-0.59343034025, 0.7, -0.39346934028, 0.3, -0.09516258196],
[0.1, -0.25918177931, 0.5, -0.5034146962, 0.9]]),
self._npElu(
- np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
- ])))
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]])))
def _testElu(self, np_features, use_gpu=False):
np_elu = self._npElu(np_features)
@@ -330,11 +330,11 @@ class SeluTest(test.TestCase):
def testNpSelu(self):
self.assertAllClose(
- np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103 , -0.16730527],
- [0.1050701 , -0.45566732, 0.5253505, -0.88505305, 0.9456309]]),
+ np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103, -0.16730527],
+ [0.1050701, -0.45566732, 0.5253505, -0.88505305, 0.9456309]]),
self._npSelu(
- np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
- ])))
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]])))
def _testSelu(self, np_features, use_gpu=False):
np_selu = self._npSelu(np_features)
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 7b131a5b8c..b4b555591d 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+@test_util.with_c_api
class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
def tearDown(self):
@@ -342,14 +343,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(
2.0, caching_device="/job:localhost")
self.assertEqual("/job:localhost", v.value().device)
- with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
+ with self.assertRaises(ValueError):
_ = v.value().op.get_attr("_class")
with ops.colocate_with(v.op):
w = resource_variable_ops.ResourceVariable(
2.0, caching_device="/job:localhost")
self.assertEqual("/job:localhost", w.value().device)
- with self.assertRaisesRegexp(ValueError, "No attr named '_class'"):
+ with self.assertRaises(ValueError):
_ = w.value().op.get_attr("_class")
def testSharedName(self):
diff --git a/tensorflow/python/kernel_tests/scalar_test.py b/tensorflow/python/kernel_tests/scalar_test.py
index b34426cc21..e65241981e 100644
--- a/tensorflow/python/kernel_tests/scalar_test.py
+++ b/tensorflow/python/kernel_tests/scalar_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
@@ -30,6 +31,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
+@test_util.with_c_api
class ScalarTest(test.TestCase):
def check(self, op, args, error, correct=None):
@@ -51,7 +53,7 @@ class ScalarTest(test.TestCase):
# Test various GraphDef versions
for version in strict + lenient:
with ops.Graph().as_default() as g:
- g.graph_def_versions.producer = version
+ test_util.set_producer_version(g, version)
with self.test_session(graph=g) as sess:
feed = {}
xs = placeholders(args, feed)
diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
index 762e400447..da116601f8 100644
--- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
@@ -32,11 +32,12 @@ class SparseSliceOpTest(test.TestCase):
# [ |11| |13|14| ]
# [20| | |23| |25]
# [30| |32|33| |35]
- ind = np.array([[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4],
- [2, 0], [2, 3], [2, 5], [3, 0], [3, 2], [3, 3],
- [3, 5]]).astype(np.int64)
- val = np.array(
- [0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(np.int64)
+ ind = np.array([[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1,
+ 4], [2, 0],
+ [2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype(
+ np.int64)
+ val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(
+ np.int64)
shape = np.array([4, 6]).astype(np.int64)
return sparse_tensor.SparseTensor(ind, val, shape)
@@ -65,50 +66,49 @@ class SparseSliceOpTest(test.TestCase):
# [ |'c1'| |'d1']
# [ | |'e1'| ]
ind = np.array([[0, 0, 0], [0, 0, 1], [0, 2, 0], [0, 2, 1], [1, 1, 0],
- [1, 1, 1], [1, 3, 0], [1, 3, 1], [2, 2, 0],
- [2, 2, 1]]).astype(np.int64)
+ [1, 1, 1], [1, 3, 0], [1, 3, 1], [2, 2, 0], [2, 2,
+ 1]]).astype(
+ np.int64)
val = np.array(['a0', 'a1', 'b0', 'b1', 'c0', 'c1', 'd0', 'd1', 'e0', 'e1'])
shape = np.array([3, 4, 2]).astype(np.int64)
return sparse_tensor.SparseTensorValue(ind, val, shape)
def _SparseTensor_3x4x2(self):
- return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x4x2(
- ))
+ return sparse_tensor.SparseTensor.from_value(
+ self._SparseTensorValue_3x4x2())
def testSliceMatrixRows(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_4x6()
+ sp_input = self._SparseTensor_4x6()
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [2, 6])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [2, 0], [3, 7])
- self.assertAllEqual(sp_tensor0.indices.eval(), [[0, 0], [0, 2], [0, 4],
- [0, 5], [1, 1], [1, 3],
- [1, 4]])
+ self.assertAllEqual(
+ sp_tensor0.indices.eval(),
+ [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4]])
self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 4, 5, 11, 13, 14])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [2, 6])
- self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 0], [0, 3], [0, 5],
- [1, 0], [1, 2], [1, 3],
- [1, 5]])
+ self.assertAllEqual(
+ sp_tensor1.indices.eval(),
+ [[0, 0], [0, 3], [0, 5], [1, 0], [1, 2], [1, 3], [1, 5]])
self.assertAllEqual(sp_tensor1.values.eval(),
[20, 23, 25, 30, 32, 33, 35])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 6])
def testSliceMatrixUnevenCols(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_5x7()
+ sp_input = self._SparseTensor_5x7()
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [5, 3])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 3], [5, 2])
sp_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 5], [5, 2])
- self.assertAllEqual(sp_tensor0.indices.eval(),
- [[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2],
- [4, 1]])
- self.assertAllEqual(sp_tensor0.values.eval(),
- [0, 2, 11, 20, 30, 32, 41])
+ self.assertAllEqual(
+ sp_tensor0.indices.eval(),
+ [[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2], [4, 1]])
+ self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 11, 20, 30, 32, 41])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [5, 3])
self.assertAllEqual(sp_tensor1.indices.eval(),
[[0, 1], [1, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
- self.assertAllEqual(sp_tensor1.values.eval(),
- [4, 13, 14, 23, 33, 44])
+ self.assertAllEqual(sp_tensor1.values.eval(), [4, 13, 14, 23, 33, 44])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [5, 2])
self.assertAllEqual(sp_tensor2.indices.eval(),
[[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
@@ -137,7 +137,7 @@ class SparseSliceOpTest(test.TestCase):
def testSliceMatrixUnevenRows(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_5x7()
+ sp_input = self._SparseTensor_5x7()
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [3, 7])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [3, 0], [3, 7])
self.assertAllEqual(sp_tensor0.indices.eval(),
@@ -146,9 +146,9 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sp_tensor0.values.eval(),
[0, 2, 4, 5, 11, 13, 14, 16, 20, 23, 25])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [3, 7])
- self.assertAllEqual(sp_tensor1.indices.eval(),
- [[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4],
- [1, 6]])
+ self.assertAllEqual(
+ sp_tensor1.indices.eval(),
+ [[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4], [1, 6]])
self.assertAllEqual(sp_tensor1.values.eval(),
[30, 32, 33, 35, 41, 44, 46])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 7])
@@ -156,9 +156,9 @@ class SparseSliceOpTest(test.TestCase):
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [2, 7])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [2, 0], [2, 7])
sp_tensor2 = sparse_ops.sparse_slice(sp_input, [4, 0], [2, 7])
- self.assertAllEqual(sp_tensor0.indices.eval(),
- [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3],
- [1, 4], [1, 6]])
+ self.assertAllEqual(
+ sp_tensor0.indices.eval(),
+ [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4], [1, 6]])
self.assertAllEqual(sp_tensor0.values.eval(),
[0, 2, 4, 5, 11, 13, 14, 16])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [2, 7])
@@ -166,45 +166,42 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sp_tensor1.values.eval(),
[20, 23, 25, 30, 32, 33, 35])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 7])
- self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 1], [0, 4],
- [0, 6]])
+ self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 1], [0, 4], [0, 6]])
self.assertAllEqual(sp_tensor2.values.eval(), [41, 44, 46])
self.assertAllEqual(sp_tensor2.dense_shape.eval(), [1, 7])
return
def testSliceAllRows(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_4x6()
+ sp_input = self._SparseTensor_4x6()
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [1, 6])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [1, 0], [1, 6])
sp_tensor2 = sparse_ops.sparse_slice(sp_input, [2, 0], [1, 7])
sp_tensor3 = sparse_ops.sparse_slice(sp_input, [3, 0], [2, 7])
- self.assertAllEqual(sp_tensor0.indices.eval(), [[0, 0], [0, 2], [0, 4],
- [0, 5]])
+ self.assertAllEqual(sp_tensor0.indices.eval(),
+ [[0, 0], [0, 2], [0, 4], [0, 5]])
self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 4, 5])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [1, 6])
- self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 1], [0, 3], [0,
- 4]])
+ self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 1], [0, 3], [0, 4]])
self.assertAllEqual(sp_tensor1.values.eval(), [11, 13, 14])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [1, 6])
- self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 0], [0, 3], [0,
- 5]])
+ self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 0], [0, 3], [0, 5]])
self.assertAllEqual(sp_tensor2.values.eval(), [20, 23, 25])
self.assertAllEqual(sp_tensor2.dense_shape.eval(), [1, 6])
- self.assertAllEqual(sp_tensor3.indices.eval(), [[0, 0], [0, 2], [0, 3],
- [0, 5]])
+ self.assertAllEqual(sp_tensor3.indices.eval(),
+ [[0, 0], [0, 2], [0, 3], [0, 5]])
self.assertAllEqual(sp_tensor3.values.eval(), [30, 32, 33, 35])
self.assertAllEqual(sp_tensor3.dense_shape.eval(), [1, 6])
def testSliceColumns(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_4x6()
+ sp_input = self._SparseTensor_4x6()
sparse_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [4, 2])
sparse_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 2], [5, 2])
sparse_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 4], [5, 3])
- self.assertAllEqual(sparse_tensor0.indices.eval(), [[0, 0], [1, 1],
- [2, 0], [3, 0]])
+ self.assertAllEqual(sparse_tensor0.indices.eval(),
+ [[0, 0], [1, 1], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensor0.values.eval(), [0, 11, 20, 30])
self.assertAllEqual(sparse_tensor0.dense_shape.eval(), [4, 2])
self.assertAllEqual(sparse_tensor1.indices.eval(),
@@ -218,15 +215,15 @@ class SparseSliceOpTest(test.TestCase):
def testSliceAllColumns(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_4x6()
+ sp_input = self._SparseTensor_4x6()
sparse_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [4, 1])
sparse_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 1], [4, 1])
sparse_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 2], [4, 1])
sparse_tensor3 = sparse_ops.sparse_slice(sp_input, [0, 3], [4, 1])
sparse_tensor4 = sparse_ops.sparse_slice(sp_input, [0, 4], [5, 1])
sparse_tensor5 = sparse_ops.sparse_slice(sp_input, [0, 5], [6, 3])
- self.assertAllEqual(sparse_tensor0.indices.eval(), [[0, 0], [2, 0],
- [3, 0]])
+ self.assertAllEqual(sparse_tensor0.indices.eval(),
+ [[0, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensor0.values.eval(), [0, 20, 30])
self.assertAllEqual(sparse_tensor0.dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensor1.indices.eval(), [[1, 0]])
@@ -235,17 +232,18 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sparse_tensor2.indices.eval(), [[0, 0], [3, 0]])
self.assertAllEqual(sparse_tensor2.values.eval(), [2, 32])
self.assertAllEqual(sparse_tensor2.dense_shape.eval(), [4, 1])
- self.assertAllEqual(sparse_tensor3.indices.eval(), [[1, 0], [2, 0],
- [3, 0]])
+ self.assertAllEqual(sparse_tensor3.indices.eval(),
+ [[1, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensor3.dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensor3.values.eval(), [13, 23, 33])
self.assertAllEqual(sparse_tensor4.indices.eval(), [[0, 0], [1, 0]])
self.assertAllEqual(sparse_tensor4.values.eval(), [4, 14])
self.assertAllEqual(sparse_tensor4.dense_shape.eval(), [4, 1])
- self.assertAllEqual(sparse_tensor5.indices.eval(), [[0, 0], [2, 0],
- [3, 0]])
+ self.assertAllEqual(sparse_tensor5.indices.eval(),
+ [[0, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35])
self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py
index 64b3388c5c..dd06d30391 100644
--- a/tensorflow/python/kernel_tests/stage_op_test.py
+++ b/tensorflow/python/kernel_tests/stage_op_test.py
@@ -25,8 +25,8 @@ from tensorflow.python.platform import test
TIMEOUT = 1
-class StageTest(test.TestCase):
+class StageTest(test.TestCase):
def testSimple(self):
with ops.Graph().as_default() as G:
@@ -116,7 +116,10 @@ class StageTest(test.TestCase):
x = array_ops.placeholder(dtypes.int32, name='x')
p = array_ops.placeholder(dtypes.int32, name='p')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.StagingArea([dtypes.int32, ], shapes=[[]])
+ stager = data_flow_ops.StagingArea(
+ [
+ dtypes.int32,
+ ], shapes=[[]])
stage = stager.put([x])
peek = stager.peek(p)
ret = stager.get()
@@ -162,8 +165,10 @@ class StageTest(test.TestCase):
with ops.device('/cpu:0'):
x = array_ops.placeholder(dtypes.int32, name='x')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.StagingArea([dtypes.int32, ],
- capacity=capacity, shapes=[[]])
+ stager = data_flow_ops.StagingArea(
+ [
+ dtypes.int32,
+ ], capacity=capacity, shapes=[[]])
stage = stager.put([x])
ret = stager.get()
size = stager.size()
@@ -201,9 +206,8 @@ class StageTest(test.TestCase):
self.fail("Expected to timeout on iteration '{}' "
"but instead timed out on iteration '{}' "
"Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
+ "capacity is '{}'.".format(capacity, i, sess.run(size),
+ capacity))
# Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
@@ -216,16 +220,18 @@ class StageTest(test.TestCase):
self.assertTrue(sess.run(size) == 0)
def testMemoryLimit(self):
- memory_limit = 512*1024 # 512K
- chunk = 200*1024 # 256K
+ memory_limit = 512 * 1024 # 512K
+ chunk = 200 * 1024 # 256K
capacity = memory_limit // chunk
with ops.Graph().as_default() as G:
with ops.device('/cpu:0'):
x = array_ops.placeholder(dtypes.uint8, name='x')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.StagingArea([dtypes.uint8, ],
- memory_limit=memory_limit, shapes=[[]])
+ stager = data_flow_ops.StagingArea(
+ [
+ dtypes.uint8,
+ ], memory_limit=memory_limit, shapes=[[]])
stage = stager.put([x])
ret = stager.get()
size = stager.size()
@@ -264,9 +270,8 @@ class StageTest(test.TestCase):
self.fail("Expected to timeout on iteration '{}' "
"but instead timed out on iteration '{}' "
"Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
+ "capacity is '{}'.".format(capacity, i, sess.run(size),
+ capacity))
# Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
@@ -277,5 +282,6 @@ class StageTest(test.TestCase):
self.assertTrue(sess.run(size) == 0)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index f375157287..38205518b5 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -64,7 +64,7 @@ class TensordotTest(test_lib.TestCase):
a = [[1, 2], [3, 4]]
b = [[1, 2], [3, 4]]
# Invalid static axes.
- for axes_value in -1, 0, [1], [[1]], [[1], [0, 1]]:
+ for axes_value in -1, 3, [1], [[1]], [[1], [0, 1]]:
with self.assertRaises(ValueError):
math_ops.tensordot(a, b, axes_value)
@@ -87,7 +87,7 @@ class TensordotTest(test_lib.TestCase):
# Test case for 11950
def test_valid_axis(self):
- for axes_value in [1, 2], [[1], [2]]:
+ for axes_value in [1, 2], [[1], [2]], [[], []], 0:
with self.test_session() as sess:
np_a = np.ones((3,3))
np_b = np.array([2, 3, 1])[None, None]
@@ -102,29 +102,29 @@ class TensordotTest(test_lib.TestCase):
def test_partial_shape_inference(self):
- a = array_ops.placeholder(dtypes.float32)
- b = array_ops.placeholder(dtypes.float32)
- axes = ([1], [0])
- output = math_ops.tensordot(a, b, axes)
- self.assertEqual(output.get_shape().ndims, None)
- a.set_shape([None, 2])
- b.set_shape([2, 3])
- output = math_ops.tensordot(a, b, axes)
- output_shape = output.get_shape()
- self.assertEqual(output_shape.ndims, 2)
- output_shape = output_shape.as_list()
- self.assertEqual(output_shape[0], None)
- self.assertEqual(output_shape[1], 3)
- a = array_ops.placeholder(dtypes.float32)
- b = array_ops.placeholder(dtypes.float32)
- a.set_shape([2, 2])
- b.set_shape([2, None])
- output = math_ops.tensordot(a, b, axes)
- output_shape = output.get_shape()
- self.assertEqual(output_shape.ndims, 2)
- output_shape = output_shape.as_list()
- self.assertEqual(output_shape[0], 2)
- self.assertEqual(output_shape[1], None)
+ for axes in ([1],[0]), 1:
+ a = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ output = math_ops.tensordot(a, b, axes)
+ self.assertEqual(output.get_shape().ndims, None)
+ a.set_shape([None, 2])
+ b.set_shape([2, 3])
+ output = math_ops.tensordot(a, b, axes)
+ output_shape = output.get_shape()
+ self.assertEqual(output_shape.ndims, 2)
+ output_shape = output_shape.as_list()
+ self.assertEqual(output_shape[0], None)
+ self.assertEqual(output_shape[1], 3)
+ a = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ a.set_shape([2, 2])
+ b.set_shape([2, None])
+ output = math_ops.tensordot(a, b, axes)
+ output_shape = output.get_shape()
+ self.assertEqual(output_shape.ndims, 2)
+ output_shape = output_shape.as_list()
+ self.assertEqual(output_shape[0], 2)
+ self.assertEqual(output_shape[1], None)
def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
@@ -191,8 +191,8 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
b_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
- all_axes = [1]
- if a_np.ndim > 1:
+ all_axes = [0, 1]
+ if a_np.ndim > 2:
all_axes.append(a_np.ndim - 1)
for axes in all_axes:
np_ans = np.tensordot(a_np, b_np, axes=axes)
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index f1a86625e0..8527f116f9 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -131,6 +131,30 @@ class VariableScopeTest(test.TestCase):
self.assertFalse(v in store.non_trainable_variables())
self.assertTrue(w in store.non_trainable_variables())
+ # Test copying.
+ new_store = store.copy()
+ with new_store.as_default():
+ new_v = variable_scope.get_variable("v")
+ new_w = variable_scope.get_variable("w")
+ self.assertEqual(new_v.numpy(), v.numpy())
+ self.assertEqual(new_w.numpy(), w.numpy())
+ self.assertTrue(new_v in new_store.variables())
+ self.assertTrue(new_w in new_store.variables())
+ self.assertTrue(new_v in new_store.trainable_variables())
+ self.assertFalse(new_w in new_store.trainable_variables())
+ self.assertFalse(new_v in new_store.non_trainable_variables())
+ self.assertTrue(new_w in new_store.non_trainable_variables())
+
+ # Check that variables are separate instances.
+ for v in store.variables():
+ v.assign(-1)
+ for v in new_store.variables():
+ v.assign(1)
+ for v in store.variables():
+ self.assertEqual(v.numpy(), -1)
+ for v in new_store.variables():
+ self.assertEqual(v.numpy(), 1)
+
@test_util.run_in_graph_and_eager_modes()
def testInitFromNonTensorValue(self):
v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
@@ -1253,6 +1277,24 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
(((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3]))
+ ((np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7]))))
+ def testVariableCreator(self):
+
+ variable_names = []
+
+ def creator_a(next_creator, **kwargs):
+ variable_names.append(kwargs.get("name", ""))
+ return next_creator(**kwargs)
+
+ def creator_b(next_creator, **kwargs):
+ kwargs["name"] = "forced_name"
+ return next_creator(**kwargs)
+
+ with variable_scope.variable_creator_scope(creator_a):
+ with variable_scope.variable_creator_scope(creator_b):
+ variable_scope.variable(1.0, name="one_name")
+
+ self.assertAllEqual(variable_names, ["forced_name"])
+
class PartitionInfoTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 43be08f8a1..c6c7c4e26c 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -240,6 +240,16 @@ class XentTest(test.TestCase):
self._testXentWrapper(features, labels, dim=-1, use_gpu=False)
self._testXentWrapper(features, labels, dim=-1, use_gpu=True)
+ def testZeroDimension(self):
+ features = np.zeros([0, 2, 4]).astype(np.float32)
+ labels = np.zeros([0, 2, 4]).astype(np.float32)
+ np_loss, _ = self._npXent(features, labels)
+ with self.test_session(use_gpu=True) as sess:
+ loss = nn_ops.softmax_cross_entropy_with_logits(
+ labels=labels, logits=features)
+ tf_loss = sess.run(loss)
+ self.assertAllEqual(np_loss, tf_loss)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 00faf3faa1..5d9feb07b4 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -99,8 +99,16 @@ class Layer(object):
raise TypeError('Keyword argument not understood:', kwarg)
# Mutable properties
+ # Indicates whether the layer's weights are updated during training
+ # and whether the layer's updates are run during training
self.trainable = trainable
+ # A stateful layer is a layer whose updates are run during inference too,
+ # for instance stateful RNNs.
+ self.stateful = False
+ # Indicates whether `build` needs to be called upon layer call, to create
+ # the layer's weights.
self.built = False
+ # Provides information about which inputs are compatible with the layer.
self.input_spec = None
if activity_regularizer and context.in_eager_mode():
@@ -223,6 +231,8 @@ class Layer(object):
def updates(self):
if context.in_eager_mode():
raise RuntimeError('Layer.updates not supported in Eager mode.')
+ if not self.trainable and not self.stateful:
+ return []
return self._updates
def add_update(self, updates, inputs=None):
@@ -284,6 +294,8 @@ class Layer(object):
"""
if context.in_eager_mode():
raise RuntimeError('Layer.get_updates_for not supported in Eager mode.')
+ if not self.trainable and not self.stateful:
+ return []
if inputs is not None:
inputs = nest.flatten(inputs)
if not inputs:
@@ -500,13 +512,30 @@ class Layer(object):
instance is returned.
Raises:
- RuntimeError: If called in Eager mode with partioned variable
- regularization.
+ RuntimeError: If called with partioned variable regularization and
+ eager execution is enabled.
"""
- in_graph_mode = context.in_graph_mode()
- if in_graph_mode:
- existing_variables = set(tf_variables.global_variables())
+ # `init_graph` should point to the graph in which variable initialization
+ # will occur; it should be None if and only if initialization will take
+ # place in the eager context.
+ init_graph = None
+ if context.in_graph_mode():
+ default_graph = ops.get_default_graph()
+ if default_graph.building_function:
+ with ops.init_scope():
+ # Retrieve the variables from the graph into which variables
+ # will be lifted; if initialization ops will be lifted into
+ # the eager context, then there is nothing to retrieve, since variable
+ # collections are not supported when eager execution is enabled.
+ if context.in_graph_mode():
+ init_graph = ops.get_default_graph()
+ existing_variables = set(tf_variables.global_variables())
+ else:
+ # Initialization ops will not be lifted out of the default graph.
+ init_graph = default_graph
+ existing_variables = set(tf_variables.global_variables())
+
if dtype is None:
dtype = self.dtype or dtypes.float32
@@ -523,54 +552,51 @@ class Layer(object):
trainable=trainable and self.trainable,
partitioner=partitioner)
- if in_graph_mode:
- if (trainable and self.trainable
- and variable not in tf_variables.trainable_variables()):
- # A custom getter / variable scope overrode the trainable flag.
- trainable = False
+ if init_graph is not None: # pylint: disable=protected-access
+ # The variable was created and initialized in a graph.
+
if variable in existing_variables:
# To match the behavior of tf.get_variable(), we only apply
# regularization if the variable is newly created.
return variable
- if regularizer:
- def regularizer_factory():
- if context.in_graph_mode():
- with vs.variable_scope(scope, reuse=reuse,
- auxiliary_name_scope=False):
- with ops.name_scope(self._name_scope_name(scope)):
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
+ with init_graph.as_default():
+ trainable_variables = tf_variables.trainable_variables()
+ if (trainable and self.trainable and
+ variable not in trainable_variables):
+ # A custom getter / variable scope overrode the trainable flag.
+ trainable = False
+
+ if regularizer:
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ with ops.colocate_with(v.op):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ if regularization is not None:
+ self.add_loss(regularization)
else:
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request'
- 'if this is important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested.
- # This makes losses responsive to variable updates when
- # executing eagerly.
- self._losses.append(lambda: regularizer(variable))
-
- if hasattr(self, '_defer_regularizers') and self._defer_regularizers:
- # _defer_regularizers exists and is set to True if `build` was
- # invoked in `__call__`: deferring regularizer construction
- # prevents the regularizer from being created in an `init_scope`.
- self._get_regularizer_factories().append(regularizer_factory)
- else:
- regularizer_factory()
+ with ops.colocate_with(variable.op):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(variable)
+ if regularization is not None:
+ self.add_loss(regularization)
+ elif regularizer: # and initialization took place in an eager context
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ raise RuntimeError(
+ 'Partitioned variable regularization is not yet '
+ 'supported when executing eagerly. File a feature request'
+ 'if this is important to you.')
+ # Save a zero-argument lambda which runs the regularizer on the
+ # variable, to be executed when `Layer.losses` is requested.
+ # This makes losses responsive to variable updates when executing
+ # eagerly.
+ #
+ # TODO(akshayka): Do the same for graphs as well, so that losses
+ # collected in a while_loop can be run outside its control flow
+ # context and so that losses won't be swallowed up by graph functions
+ # (i.e., `.losses()` should always create regularizers).
+ self._losses.append(lambda: regularizer(variable))
if trainable:
self._trainable_weights.append(variable)
@@ -670,15 +696,7 @@ class Layer(object):
except AttributeError:
pass
input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
-
- # Signal to `add_variable` that regularizer construction should be
- # deferred.
- self._defer_regularizers = True
- with ops.init_scope():
- self.build(input_shapes)
- # Create any regularizers added by `build`.
- self._maybe_create_variable_regularizers()
- self._defer_regularizers = False
+ self.build(input_shapes)
try:
# Note: not all sub-classes of Layer call Layer.__init__ (especially
# the ones under tensorflow/python/keras). Hence we recompute this
@@ -1263,6 +1281,15 @@ class InputSpec(object):
self.min_ndim = min_ndim
self.axes = axes or {}
+ def __repr__(self):
+ spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
+ ('shape=' + str(self.shape)) if self.shape else '',
+ ('ndim=' + str(self.ndim)) if self.ndim else '',
+ ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
+ ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
+ ('axes=' + str(self.axes)) if self.axes else '']
+ return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
+
class Node(object):
"""A `Node` describes the connectivity between two layers.
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index ab1fa551e1..e8dba3cea3 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -819,8 +819,8 @@ def conv3d(inputs,
return layer.apply(inputs)
-class SeparableConv2D(Conv2D):
- """Depthwise separable 2D convolution.
+class _SeparableConv(_Conv):
+ """Abstract base layer for separable nD convolution.
This layer performs a depthwise convolution that acts separately on
channels, followed by a pointwise convolution that mixes channels.
@@ -829,12 +829,13 @@ class SeparableConv2D(Conv2D):
It then optionally applies an activation function to produce the final output.
Arguments:
+ rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
filters: Integer, the dimensionality of the output space (i.e. the number
of filters in the convolution).
- kernel_size: A tuple or list of 2 integers specifying the spatial
+ kernel_size: A tuple or list of integers specifying the spatial
dimensions of the filters. Can be a single integer to specify the same
value for all spatial dimensions.
- strides: A tuple or list of 2 positive integers specifying the strides
+ strides: A tuple or list of integers specifying the strides
of the convolution. Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any `stride` value != 1 is incompatible with specifying
@@ -843,9 +844,8 @@ class SeparableConv2D(Conv2D):
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, height, width, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, height, width)`.
-
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
dilation_rate: An integer or tuple/list of 2 integers, specifying
the dilation rate to use for dilated convolution.
Can be a single integer to specify the same value for
@@ -883,12 +883,14 @@ class SeparableConv2D(Conv2D):
name: A string, the name of the layer.
"""
- def __init__(self, filters,
+ def __init__(self,
+ rank,
+ filters,
kernel_size,
- strides=(1, 1),
+ strides=1,
padding='valid',
data_format='channels_last',
- dilation_rate=(1, 1),
+ dilation_rate=1,
depth_multiplier=1,
activation=None,
use_bias=True,
@@ -905,7 +907,8 @@ class SeparableConv2D(Conv2D):
trainable=True,
name=None,
**kwargs):
- super(SeparableConv2D, self).__init__(
+ super(_SeparableConv, self).__init__(
+ rank=rank,
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -920,7 +923,6 @@ class SeparableConv2D(Conv2D):
trainable=trainable,
name=name,
**kwargs)
- self.data_format = data_format
self.depth_multiplier = depth_multiplier
self.depthwise_initializer = depthwise_initializer
self.pointwise_initializer = pointwise_initializer
@@ -930,26 +932,21 @@ class SeparableConv2D(Conv2D):
self.pointwise_constraint = pointwise_constraint
def build(self, input_shape):
- if len(input_shape) < 4:
- raise ValueError('Inputs to `SeparableConv2D` should have rank 4. '
- 'Received input shape:', str(input_shape))
+ input_shape = tensor_shape.TensorShape(input_shape)
if self.data_format == 'channels_first':
channel_axis = 1
else:
- channel_axis = 3
- if input_shape[channel_axis] is None:
- raise ValueError('The channel dimension of the inputs to '
- '`SeparableConv2D` '
+ channel_axis = -1
+ if input_shape[channel_axis].value is None:
+ raise ValueError('The channel dimension of the inputs '
'should be defined. Found `None`.')
- input_dim = int(input_shape[channel_axis])
- self.input_spec = base.InputSpec(ndim=4, axes={channel_axis: input_dim})
- depthwise_kernel_shape = (self.kernel_size[0],
- self.kernel_size[1],
- input_dim,
- self.depth_multiplier)
- pointwise_kernel_shape = (1, 1,
- self.depth_multiplier * input_dim,
- self.filters)
+ input_dim = input_shape[channel_axis].value
+ self.input_spec = base.InputSpec(ndim=self.rank + 2,
+ axes={channel_axis: input_dim})
+ depthwise_kernel_shape = self.kernel_size + (input_dim,
+ self.depth_multiplier)
+ pointwise_kernel_shape = (
+ 1,) * self.rank + (self.depth_multiplier * input_dim, self.filters)
self.depthwise_kernel = self.add_variable(
name='depthwise_kernel',
@@ -980,6 +977,264 @@ class SeparableConv2D(Conv2D):
self.built = True
def call(self, inputs):
+ raise NotImplementedError
+
+
+class SeparableConv1D(_SeparableConv):
+ """Depthwise separable 1D convolution.
+
+ This layer performs a depthwise convolution that acts separately on
+ channels, followed by a pointwise convolution that mixes channels.
+ If `use_bias` is True and a bias initializer is provided,
+ it adds a bias vector to the output.
+ It then optionally applies an activation function to produce the final output.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: A single integer specifying the spatial
+ dimensions of the filters.
+ strides: A single integer specifying the strides
+ of the convolution.
+ Specifying any `stride` value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, length, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, length)`.
+ dilation_rate: A single integer, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ depth_multiplier: The number of depthwise convolution output channels for
+ each input channel. The total number of depthwise convolution output
+ channels will be equal to `num_filters_in * depth_multiplier`.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ depthwise_initializer: An initializer for the depthwise convolution kernel.
+ pointwise_initializer: An initializer for the pointwise convolution kernel.
+ bias_initializer: An initializer for the bias vector. If None, the default
+ initializer will be used.
+ depthwise_regularizer: Optional regularizer for the depthwise
+ convolution kernel.
+ pointwise_regularizer: Optional regularizer for the pointwise
+ convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Optional regularizer function for the output.
+ depthwise_constraint: Optional projection function to be applied to the
+ depthwise kernel after being updated by an `Optimizer` (e.g. used for
+ norm constraints or value constraints for layer weights). The function
+ must take as input the unprojected variable and must return the
+ projected variable (which must have the same shape). Constraints are
+ not safe to use when doing asynchronous distributed training.
+ pointwise_constraint: Optional projection function to be applied to the
+ pointwise kernel after being updated by an `Optimizer`.
+ bias_constraint: Optional projection function to be applied to the
+ bias after being updated by an `Optimizer`.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ """
+
+ def __init__(self, filters,
+ kernel_size,
+ strides=1,
+ padding='valid',
+ data_format='channels_last',
+ dilation_rate=1,
+ depth_multiplier=1,
+ activation=None,
+ use_bias=True,
+ depthwise_initializer=None,
+ pointwise_initializer=None,
+ bias_initializer=init_ops.zeros_initializer(),
+ depthwise_regularizer=None,
+ pointwise_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ depthwise_constraint=None,
+ pointwise_constraint=None,
+ bias_constraint=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(SeparableConv1D, self).__init__(
+ rank=1,
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ depth_multiplier=depth_multiplier,
+ activation=activation,
+ use_bias=use_bias,
+ depthwise_initializer=depthwise_initializer,
+ pointwise_initializer=pointwise_initializer,
+ bias_initializer=bias_initializer,
+ depthwise_regularizer=depthwise_regularizer,
+ pointwise_regularizer=pointwise_regularizer,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ depthwise_constraint=depthwise_constraint,
+ pointwise_constraint=pointwise_constraint,
+ bias_constraint=bias_constraint,
+ trainable=trainable,
+ name=name,
+ **kwargs)
+
+ def call(self, inputs):
+ if self.data_format == 'channels_last':
+ strides = (1, 1) + self.strides + (1,)
+ spatial_start_dim = 1
+ else:
+ strides = (1, 1, 1) + self.strides
+ spatial_start_dim = 2
+
+ # Explicitly broadcast inputs and kernels to 4D.
+ # TODO(fchollet): refactor when a native separable_conv1d op is available.
+ inputs = array_ops.expand_dims(inputs, spatial_start_dim)
+ depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0)
+ pointwise_kernel = array_ops.expand_dims(self.pointwise_kernel, 0)
+ dilation_rate = (1,) + self.dilation_rate
+
+ outputs = nn.separable_conv2d(
+ inputs,
+ depthwise_kernel,
+ pointwise_kernel,
+ strides=strides,
+ padding=self.padding.upper(),
+ rate=dilation_rate,
+ data_format=utils.convert_data_format(self.data_format, ndim=4))
+
+ if self.use_bias:
+ outputs = nn.bias_add(
+ outputs,
+ self.bias,
+ data_format=utils.convert_data_format(self.data_format, ndim=4))
+
+ outputs = array_ops.squeeze(outputs, [spatial_start_dim])
+
+ if self.activation is not None:
+ return self.activation(outputs)
+ return outputs
+
+
+class SeparableConv2D(_SeparableConv):
+ """Depthwise separable 2D convolution.
+
+ This layer performs a depthwise convolution that acts separately on
+ channels, followed by a pointwise convolution that mixes channels.
+ If `use_bias` is True and a bias initializer is provided,
+ it adds a bias vector to the output.
+ It then optionally applies an activation function to produce the final output.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: A tuple or list of 2 integers specifying the spatial
+ dimensions of the filters. Can be a single integer to specify the same
+ value for all spatial dimensions.
+ strides: A tuple or list of 2 positive integers specifying the strides
+ of the convolution. Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any `stride` value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
+
+ dilation_rate: An integer or tuple/list of 2 integers, specifying
+ the dilation rate to use for dilated convolution.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ depth_multiplier: The number of depthwise convolution output channels for
+ each input channel. The total number of depthwise convolution output
+ channels will be equal to `num_filters_in * depth_multiplier`.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ depthwise_initializer: An initializer for the depthwise convolution kernel.
+ pointwise_initializer: An initializer for the pointwise convolution kernel.
+ bias_initializer: An initializer for the bias vector. If None, the default
+ initializer will be used.
+ depthwise_regularizer: Optional regularizer for the depthwise
+ convolution kernel.
+ pointwise_regularizer: Optional regularizer for the pointwise
+ convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Optional regularizer function for the output.
+ depthwise_constraint: Optional projection function to be applied to the
+ depthwise kernel after being updated by an `Optimizer` (e.g. used for
+ norm constraints or value constraints for layer weights). The function
+ must take as input the unprojected variable and must return the
+ projected variable (which must have the same shape). Constraints are
+ not safe to use when doing asynchronous distributed training.
+ pointwise_constraint: Optional projection function to be applied to the
+ pointwise kernel after being updated by an `Optimizer`.
+ bias_constraint: Optional projection function to be applied to the
+ bias after being updated by an `Optimizer`.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ """
+
+ def __init__(self, filters,
+ kernel_size,
+ strides=(1, 1),
+ padding='valid',
+ data_format='channels_last',
+ dilation_rate=(1, 1),
+ depth_multiplier=1,
+ activation=None,
+ use_bias=True,
+ depthwise_initializer=None,
+ pointwise_initializer=None,
+ bias_initializer=init_ops.zeros_initializer(),
+ depthwise_regularizer=None,
+ pointwise_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ depthwise_constraint=None,
+ pointwise_constraint=None,
+ bias_constraint=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(SeparableConv2D, self).__init__(
+ rank=2,
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ depth_multiplier=depth_multiplier,
+ activation=activation,
+ use_bias=use_bias,
+ depthwise_initializer=depthwise_initializer,
+ pointwise_initializer=pointwise_initializer,
+ bias_initializer=bias_initializer,
+ depthwise_regularizer=depthwise_regularizer,
+ pointwise_regularizer=pointwise_regularizer,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ depthwise_constraint=depthwise_constraint,
+ pointwise_constraint=pointwise_constraint,
+ bias_constraint=bias_constraint,
+ trainable=trainable,
+ name=name,
+ **kwargs)
+
+ def call(self, inputs):
# Apply the actual ops.
if self.data_format == 'channels_last':
strides = (1,) + self.strides + (1,)
@@ -1004,25 +1259,121 @@ class SeparableConv2D(Conv2D):
return self.activation(outputs)
return outputs
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- if self.data_format == 'channels_first':
- rows = input_shape[2]
- cols = input_shape[3]
- else:
- rows = input_shape[1]
- cols = input_shape[2]
- rows = utils.conv_output_length(rows, self.kernel_size[0],
- self.padding, self.strides[0])
- cols = utils.conv_output_length(cols, self.kernel_size[1],
- self.padding, self.strides[1])
- if self.data_format == 'channels_first':
- return tensor_shape.TensorShape(
- [input_shape[0], self.filters, rows, cols])
- else:
- return tensor_shape.TensorShape(
- [input_shape[0], rows, cols, self.filters])
+def separable_conv1d(inputs,
+ filters,
+ kernel_size,
+ strides=1,
+ padding='valid',
+ data_format='channels_last',
+ dilation_rate=1,
+ depth_multiplier=1,
+ activation=None,
+ use_bias=True,
+ depthwise_initializer=None,
+ pointwise_initializer=None,
+ bias_initializer=init_ops.zeros_initializer(),
+ depthwise_regularizer=None,
+ pointwise_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ depthwise_constraint=None,
+ pointwise_constraint=None,
+ bias_constraint=None,
+ trainable=True,
+ name=None,
+ reuse=None):
+ """Functional interface for the depthwise separable 1D convolution layer.
+
+ This layer performs a depthwise convolution that acts separately on
+ channels, followed by a pointwise convolution that mixes channels.
+ If `use_bias` is True and a bias initializer is provided,
+ it adds a bias vector to the output.
+ It then optionally applies an activation function to produce the final output.
+
+ Arguments:
+ inputs: Input tensor.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: A single integer specifying the spatial
+ dimensions of the filters.
+ strides: A single integer specifying the strides
+ of the convolution.
+ Specifying any `stride` value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, length, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, length)`.
+ dilation_rate: A single integer, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ depth_multiplier: The number of depthwise convolution output channels for
+ each input channel. The total number of depthwise convolution output
+ channels will be equal to `num_filters_in * depth_multiplier`.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ depthwise_initializer: An initializer for the depthwise convolution kernel.
+ pointwise_initializer: An initializer for the pointwise convolution kernel.
+ bias_initializer: An initializer for the bias vector. If None, the default
+ initializer will be used.
+ depthwise_regularizer: Optional regularizer for the depthwise
+ convolution kernel.
+ pointwise_regularizer: Optional regularizer for the pointwise
+ convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Optional regularizer function for the output.
+ depthwise_constraint: Optional projection function to be applied to the
+ depthwise kernel after being updated by an `Optimizer` (e.g. used for
+ norm constraints or value constraints for layer weights). The function
+ must take as input the unprojected variable and must return the
+ projected variable (which must have the same shape). Constraints are
+ not safe to use when doing asynchronous distributed training.
+ pointwise_constraint: Optional projection function to be applied to the
+ pointwise kernel after being updated by an `Optimizer`.
+ bias_constraint: Optional projection function to be applied to the
+ bias after being updated by an `Optimizer`.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ reuse: Boolean, whether to reuse the weights of a previous layer
+ by the same name.
+
+ Returns:
+ Output tensor.
+
+ Raises:
+ ValueError: if eager execution is enabled.
+ """
+ layer = SeparableConv1D(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ depth_multiplier=depth_multiplier,
+ activation=activation,
+ use_bias=use_bias,
+ depthwise_initializer=depthwise_initializer,
+ pointwise_initializer=pointwise_initializer,
+ bias_initializer=bias_initializer,
+ depthwise_regularizer=depthwise_regularizer,
+ pointwise_regularizer=pointwise_regularizer,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ depthwise_constraint=depthwise_constraint,
+ pointwise_constraint=pointwise_constraint,
+ bias_constraint=bias_constraint,
+ trainable=trainable,
+ name=name,
+ _reuse=reuse,
+ _scope=name)
+ return layer.apply(inputs)
def separable_conv2d(inputs,
@@ -1553,6 +1904,7 @@ class Conv3DTranspose(Conv3D):
dtype=self.dtype)
else:
self.bias = None
+ self.built = True
def call(self, inputs):
inputs_shape = array_ops.shape(inputs)
@@ -1623,6 +1975,8 @@ class Conv3DTranspose(Conv3D):
if self.use_bias:
outputs_shape = outputs.shape.as_list()
+ if outputs_shape[0] is None:
+ outputs_shape[0] = -1
if self.data_format == 'channels_first':
outputs_4d = array_ops.reshape(outputs, [
outputs_shape[0], outputs_shape[1],
@@ -1656,11 +2010,11 @@ class Conv3DTranspose(Conv3D):
output_shape[c_axis] = self.filters
output_shape[d_axis] = utils.deconv_output_length(
- output_shape[d_axis], stride_d, kernel_d, self.padding)
+ output_shape[d_axis], kernel_d, self.padding, stride_d)
output_shape[h_axis] = utils.deconv_output_length(
- output_shape[h_axis], stride_h, kernel_h, self.padding)
+ output_shape[h_axis], kernel_h, self.padding, stride_h)
output_shape[w_axis] = utils.deconv_output_length(
- output_shape[w_axis], stride_w, kernel_w, self.padding)
+ output_shape[w_axis], kernel_w, self.padding, stride_w)
return tensor_shape.TensorShape(output_shape)
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index e41eb5c32f..160e732b67 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -327,6 +327,168 @@ class ConvTest(test.TestCase):
@test_util.with_c_api
+class SeparableConv1DTest(test.TestCase):
+
+ def testInvalidDataFormat(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 3), seed=1)
+ with self.assertRaisesRegexp(ValueError, 'data_format'):
+ conv_layers.separable_conv1d(data, 32, 3, data_format='invalid')
+
+ def testInvalidStrides(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 3), seed=1)
+ with self.assertRaisesRegexp(ValueError, 'strides'):
+ conv_layers.separable_conv1d(data, 32, 3, strides=(1, 2))
+
+ with self.assertRaisesRegexp(ValueError, 'strides'):
+ conv_layers.separable_conv1d(data, 32, 3, strides=None)
+
+ def testInvalidKernelSize(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 3), seed=1)
+ with self.assertRaisesRegexp(ValueError, 'kernel_size'):
+ conv_layers.separable_conv1d(data, 32, (1, 2))
+
+ with self.assertRaisesRegexp(ValueError, 'kernel_size'):
+ conv_layers.separable_conv1d(data, 32, None)
+
+ def testCreateSeparableConv1D(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 4))
+ layer = conv_layers.SeparableConv1D(32, 3, activation=nn_ops.relu)
+ output = layer.apply(data)
+ self.assertEqual(output.op.name, 'separable_conv1d/Relu')
+ self.assertEqual(output.get_shape().as_list(), [5, length - 2, 32])
+ self.assertEqual(layer.depthwise_kernel.get_shape().as_list(), [3, 4, 1])
+ self.assertEqual(layer.pointwise_kernel.get_shape().as_list(), [1, 4, 32])
+ self.assertEqual(layer.bias.get_shape().as_list(), [32])
+
+ def testCreateSeparableConv1DDepthMultiplier(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 4))
+ layer = conv_layers.SeparableConv1D(32, 3, depth_multiplier=2)
+ output = layer.apply(data)
+ self.assertEqual(output.get_shape().as_list(), [5, length - 2, 32])
+ self.assertEqual(layer.depthwise_kernel.get_shape().as_list(), [3, 4, 2])
+ self.assertEqual(layer.pointwise_kernel.get_shape().as_list(), [1, 8, 32])
+ self.assertEqual(layer.bias.get_shape().as_list(), [32])
+
+ def testCreateSeparableConv1DChannelsFirst(self):
+ length = 9
+ data = random_ops.random_uniform((5, 4, length))
+ layer = conv_layers.SeparableConv1D(32, 3, data_format='channels_first')
+ output = layer.apply(data)
+ self.assertEqual(output.get_shape().as_list(), [5, 32, length - 2])
+ self.assertEqual(layer.depthwise_kernel.get_shape().as_list(), [3, 4, 1])
+ self.assertEqual(layer.pointwise_kernel.get_shape().as_list(), [1, 4, 32])
+ self.assertEqual(layer.bias.get_shape().as_list(), [32])
+
+ def testSeparableConv1DPaddingSame(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 32), seed=1)
+ layer = conv_layers.SeparableConv1D(
+ 64, length, padding='same')
+ output = layer.apply(data)
+ self.assertEqual(output.get_shape().as_list(), [5, length, 64])
+
+ def testCreateSeparableConv1DWithStrides(self):
+ length = 10
+ data = random_ops.random_uniform((5, length, 3), seed=1)
+ layer = conv_layers.SeparableConv1D(32, 3, strides=2, padding='same')
+ output = layer.apply(data)
+ self.assertEqual(output.get_shape().as_list(), [5, length // 2, 32])
+
+ def testCreateSeparableConv1DWithStridesChannelsFirst(self):
+ data_format = 'channels_first'
+ length = 10
+ data = random_ops.random_uniform((5, 3, length), seed=1)
+ layer = conv_layers.SeparableConv1D(
+ 32, 3, strides=2, padding='same', data_format=data_format)
+ output = layer.apply(data)
+ self.assertEqual(output.get_shape().as_list(), [5, 32, length // 2])
+
+ def testFunctionalConv1DReuse(self):
+ length = 10
+ data = random_ops.random_uniform((5, length, 3), seed=1)
+ conv_layers.separable_conv1d(data, 32, 3, name='sepconv1')
+ self.assertEqual(len(variables.trainable_variables()), 3)
+ conv_layers.separable_conv1d(data, 32, 3, name='sepconv1', reuse=True)
+ self.assertEqual(len(variables.trainable_variables()), 3)
+
+ def testFunctionalConv1DReuseFromScope(self):
+ with variable_scope.variable_scope('scope'):
+ length = 10
+ data = random_ops.random_uniform((5, length, 3), seed=1)
+ conv_layers.separable_conv1d(data, 32, 3, name='sepconv1')
+ self.assertEqual(len(variables.trainable_variables()), 3)
+ with variable_scope.variable_scope('scope', reuse=True):
+ conv_layers.separable_conv1d(data, 32, 3, name='sepconv1')
+ self.assertEqual(len(variables.trainable_variables()), 3)
+
+ def testFunctionalConv1DNoReuse(self):
+ length = 10
+ data = random_ops.random_uniform((5, length, 3), seed=1)
+ conv_layers.separable_conv1d(data, 32, 3)
+ self.assertEqual(len(variables.trainable_variables()), 3)
+ conv_layers.separable_conv1d(data, 32, 3)
+ self.assertEqual(len(variables.trainable_variables()), 6)
+
+ def testSeparableConv1DDepthwiseRegularizer(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 4))
+ reg = lambda x: 0.1 * math_ops.reduce_sum(x)
+ layer = conv_layers.SeparableConv1D(32, 3, depthwise_regularizer=reg)
+ layer.apply(data)
+ loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(loss_keys), 1)
+ self.assertEqual(layer.losses, loss_keys)
+
+ def testSeparableConv1DPointwiseRegularizer(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 4))
+ reg = lambda x: 0.1 * math_ops.reduce_sum(x)
+ layer = conv_layers.SeparableConv1D(32, 3, pointwise_regularizer=reg)
+ layer.apply(data)
+ loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(loss_keys), 1)
+ self.assertEqual(layer.losses, loss_keys)
+
+ def testSeparableConv1DBiasRegularizer(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 4))
+ reg = lambda x: 0.1 * math_ops.reduce_sum(x)
+ layer = conv_layers.SeparableConv1D(32, 3, bias_regularizer=reg)
+ layer.apply(data)
+ loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+ self.assertEqual(len(loss_keys), 1)
+ self.assertEqual(layer.losses, loss_keys)
+
+ def testSeparableConv1DNoBias(self):
+ length = 9
+ data = random_ops.random_uniform((5, length, 4))
+ layer = conv_layers.SeparableConv1D(
+ 32, 3, activation=nn_ops.relu, use_bias=False)
+ output = layer.apply(data)
+ self.assertEqual(output.op.name, 'separable_conv1d/Relu')
+ self.assertEqual(layer.bias, None)
+
+ def testConstraints(self):
+ d_constraint = lambda x: x / math_ops.reduce_sum(x)
+ p_constraint = lambda x: x / math_ops.reduce_sum(x)
+ b_constraint = lambda x: x / math_ops.reduce_max(x)
+ layer = conv_layers.SeparableConv1D(2, 3,
+ depthwise_constraint=d_constraint,
+ pointwise_constraint=p_constraint,
+ bias_constraint=b_constraint)
+ inputs = random_ops.random_uniform((5, 3, 5), seed=1)
+ layer(inputs)
+ self.assertEqual(layer.depthwise_constraint, d_constraint)
+ self.assertEqual(layer.pointwise_constraint, p_constraint)
+ self.assertEqual(layer.bias_constraint, b_constraint)
+
+
+@test_util.with_c_api
class SeparableConv2DTest(test.TestCase):
def testInvalidDataFormat(self):
diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py
index 0a52b1e8d9..1555846efd 100644
--- a/tensorflow/python/layers/layers.py
+++ b/tensorflow/python/layers/layers.py
@@ -22,6 +22,7 @@
@@Conv1D
@@Conv2D
@@Conv3D
+@@SeparableConv1D
@@SeparableConv2D
@@Conv2DTranspose
@@Conv3DTranspose
@@ -43,6 +44,7 @@
@@conv1d
@@conv2d
@@conv3d
+@@separable_conv1d
@@separable_conv2d
@@conv2d_transpose
@@conv3d_transpose
@@ -78,6 +80,7 @@ from tensorflow.python.layers.core import dropout
from tensorflow.python.layers.core import flatten
# Convolutional layers.
+from tensorflow.python.layers.convolutional import SeparableConv1D
from tensorflow.python.layers.convolutional import SeparableConv2D
from tensorflow.python.layers.convolutional import SeparableConvolution2D
from tensorflow.python.layers.convolutional import Conv2DTranspose
@@ -91,6 +94,7 @@ from tensorflow.python.layers.convolutional import Convolution2D
from tensorflow.python.layers.convolutional import Conv3D
from tensorflow.python.layers.convolutional import Convolution3D
+from tensorflow.python.layers.convolutional import separable_conv1d
from tensorflow.python.layers.convolutional import separable_conv2d
from tensorflow.python.layers.convolutional import conv2d_transpose
from tensorflow.python.layers.convolutional import conv3d_transpose
diff --git a/tensorflow/python/layers/maxout.py b/tensorflow/python/layers/maxout.py
index ed048845a0..20ce6c9770 100644
--- a/tensorflow/python/layers/maxout.py
+++ b/tensorflow/python/layers/maxout.py
@@ -31,15 +31,18 @@ from tensorflow.python.layers import base
def maxout(inputs, num_units, axis=-1, name=None):
"""Adds a maxout op from https://arxiv.org/abs/1302.4389
- "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville,
+ "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron
+ Courville,
Yoshua Bengio
- Usually the operation is performed in the filter/channel dimension. This can also be
+ Usually the operation is performed in the filter/channel dimension. This can
+ also be
used after fully-connected layers to reduce number of features.
Arguments:
inputs: Tensor input
- num_units: Specifies how many features will remain after maxout in the `axis` dimension
+ num_units: Specifies how many features will remain after maxout in the `axis`
+ dimension
(usually channel). This must be multiple of number of `axis`.
axis: The dimension where max pooling will be performed. Default is the
last dimension.
@@ -57,15 +60,18 @@ def maxout(inputs, num_units, axis=-1, name=None):
class MaxOut(base.Layer):
"""Adds a maxout op from https://arxiv.org/abs/1302.4389
- "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, Yoshua
+ "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron
+ Courville, Yoshua
Bengio
- Usually the operation is performed in the filter/channel dimension. This can also be
+ Usually the operation is performed in the filter/channel dimension. This can
+ also be
used after fully-connected layers to reduce number of features.
Arguments:
inputs: Tensor input
- num_units: Specifies how many features will remain after maxout in the `axis` dimension
+ num_units: Specifies how many features will remain after maxout in the
+ `axis` dimension
(usually channel).
This must be multiple of number of `axis`.
axis: The dimension where max pooling will be performed. Default is the
@@ -79,13 +85,8 @@ class MaxOut(base.Layer):
ValueError: if num_units is not multiple of number of features.
"""
- def __init__(self,
- num_units,
- axis=-1,
- name=None,
- **kwargs):
- super(MaxOut, self).__init__(
- name=name, trainable=False, **kwargs)
+ def __init__(self, num_units, axis=-1, name=None, **kwargs):
+ super(MaxOut, self).__init__(name=name, trainable=False, **kwargs)
self.axis = axis
self.num_units = num_units
@@ -95,8 +96,8 @@ class MaxOut(base.Layer):
num_channels = shape[self.axis]
if num_channels % self.num_units:
raise ValueError('number of features({}) is not '
- 'a multiple of num_units({})'
- .format(num_channels, self.num_units))
+ 'a multiple of num_units({})'.format(
+ num_channels, self.num_units))
shape[self.axis] = -1
shape += [num_channels // self.num_units]
@@ -104,6 +105,7 @@ class MaxOut(base.Layer):
for i in range(len(shape)):
if shape[i] is None:
shape[i] = gen_array_ops.shape(inputs)[i]
- outputs = math_ops.reduce_max(gen_array_ops.reshape(inputs, shape), -1, keep_dims=False)
+ outputs = math_ops.reduce_max(
+ gen_array_ops.reshape(inputs, shape), -1, keep_dims=False)
return outputs
diff --git a/tensorflow/python/layers/network.py b/tensorflow/python/layers/network.py
index ade57da411..0a5dd57621 100644
--- a/tensorflow/python/layers/network.py
+++ b/tensorflow/python/layers/network.py
@@ -575,6 +575,11 @@ class GraphNetwork(base.Layer):
raise ValueError('No such layer: ' + name)
@property
+ def stateful(self):
+ return any([(hasattr(layer, 'stateful') and layer.stateful)
+ for layer in self.layers])
+
+ @property
def updates(self):
"""Retrieve the network's updates.
@@ -586,6 +591,8 @@ class GraphNetwork(base.Layer):
Returns:
A list of update ops.
"""
+ if not self.trainable and not self.stateful:
+ return []
updates = []
for layer in self.layers:
if hasattr(layer, 'updates'):
diff --git a/tensorflow/python/layers/pooling.py b/tensorflow/python/layers/pooling.py
index c6bd7aae07..ab06a3a408 100644
--- a/tensorflow/python/layers/pooling.py
+++ b/tensorflow/python/layers/pooling.py
@@ -63,14 +63,18 @@ class _Pooling1D(base.Layer):
def call(self, inputs):
# There is no TF op for 1D pooling, hence we make the inputs 4D.
if self.data_format == 'channels_last':
- inputs = array_ops.expand_dims(inputs, 2)
- pool_shape = (1,) + self.pool_size + (1, 1)
- strides = (1,) + self.strides + (1, 1)
- data_format = 'NHWC'
- else:
+ # input is NWC, make it NHWC
inputs = array_ops.expand_dims(inputs, 1)
+ # pool on the W dim
pool_shape = (1, 1) + self.pool_size + (1,)
strides = (1, 1) + self.strides + (1,)
+ data_format = 'NHWC'
+ else:
+ # input is NCW, make it NCHW
+ inputs = array_ops.expand_dims(inputs, 2)
+ # pool on the W dim
+ pool_shape = (1, 1, 1) + self.pool_size
+ strides = (1, 1, 1) + self.strides
data_format = 'NCHW'
outputs = self.pool_function(
@@ -81,9 +85,9 @@ class _Pooling1D(base.Layer):
data_format=data_format)
if self.data_format == 'channels_last':
- return array_ops.squeeze(outputs, 2)
- else:
return array_ops.squeeze(outputs, 1)
+ else:
+ return array_ops.squeeze(outputs, 2)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
diff --git a/tensorflow/python/layers/pooling_test.py b/tensorflow/python/layers/pooling_test.py
index 589fee5f71..7533674e5a 100644
--- a/tensorflow/python/layers/pooling_test.py
+++ b/tensorflow/python/layers/pooling_test.py
@@ -96,33 +96,41 @@ class PoolingTest(test.TestCase):
def testCreateMaxPooling1D(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ channels = 3
+ images = random_ops.random_uniform((5, width, channels))
layer = pooling_layers.MaxPooling1D(2, strides=2)
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, width // 2, channels])
def testCreateAveragePooling1D(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ channels = 3
+ images = random_ops.random_uniform((5, width, channels))
layer = pooling_layers.AveragePooling1D(2, strides=2)
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, width // 2, channels])
def testCreateMaxPooling1DChannelsFirst(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ channels = 3
+ images = random_ops.random_uniform((5, channels, width))
layer = pooling_layers.MaxPooling1D(
2, strides=2, data_format='channels_first')
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, channels, width // 2])
def testCreateAveragePooling1DChannelsFirst(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ channels = 3
+ images = random_ops.random_uniform((5, channels, width))
layer = pooling_layers.AveragePooling1D(
2, strides=2, data_format='channels_first')
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, channels, width // 2])
def testCreateMaxPooling3D(self):
depth, height, width = 6, 7, 9
diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py
index e8be347799..7407d9a7b3 100644
--- a/tensorflow/python/layers/utils.py
+++ b/tensorflow/python/layers/utils.py
@@ -81,7 +81,7 @@ def normalize_tuple(value, n, name):
for single_value in value_tuple:
try:
int(single_value)
- except ValueError:
+ except (ValueError, TypeError):
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value) + ' '
'including element ' + str(single_value) + ' of type' +
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index 985a11272c..09d4b01fa4 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -25,6 +25,7 @@ import numpy as np
# pylint: disable=unused-import,g-bad-import-order
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
@@ -160,6 +161,24 @@ class Bfloat16Test(test.TestCase):
for w in self.float_values():
self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
+ def testNan(self):
+ a = np.isnan(bfloat16(float("nan")))
+ self.assertTrue(a)
+ np.testing.assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
+
+ a = np.array(
+ [bfloat16(1.34375),
+ bfloat16(1.4375),
+ bfloat16(float("nan"))],
+ dtype=dtypes.bfloat16.as_numpy_dtype)
+ b = np.array(
+ [bfloat16(1.3359375),
+ bfloat16(1.4375),
+ bfloat16(float("nan"))],
+ dtype=dtypes.bfloat16.as_numpy_dtype)
+ np.testing.assert_allclose(
+ a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
+
class Bfloat16NumPyTest(test.TestCase):
diff --git a/tensorflow/python/lib/core/ndarray_tensor.h b/tensorflow/python/lib/core/ndarray_tensor.h
index 5172d504bd..b2cd4133ca 100644
--- a/tensorflow/python/lib/core/ndarray_tensor.h
+++ b/tensorflow/python/lib/core/ndarray_tensor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
-#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
// Must be included first.
#include "tensorflow/python/lib/core/numpy.h"
@@ -45,4 +45,4 @@ Status TensorToNdarray(const Tensor& t, PyObject** ret);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
+#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_H_
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index dc56b39486..d3bfa0ee33 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
+
#include <Python.h>
namespace tensorflow {
@@ -141,7 +142,8 @@ bool IsSingleNone(PyObject* obj) {
return false;
}
std::array<npy_intp, 0> indices;
- char* item_ptr = static_cast<char*>(PyArray_GetPtr(array_obj, indices.data()));
+ char* item_ptr =
+ static_cast<char*>(PyArray_GetPtr(array_obj, indices.data()));
PyObject* item = PyArray_GETITEM(array_obj, item_ptr);
CHECK(item);
return item == Py_None;
@@ -301,13 +303,22 @@ Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
if (PyBytes_AsStringAndSize(input_data[i], &el, &el_size) == -1) {
#if PY_MAJOR_VERSION >= 3
el = PyUnicode_AsUTF8AndSize(input_data[i], &el_size);
- if (!el) {
+#else
+ el = nullptr;
+ if (PyUnicode_Check(input_data[i])) {
+ PyObject* unicode = PyUnicode_AsUTF8String(input_data[i]);
+ if (unicode) {
+ if (PyString_AsStringAndSize(unicode, &el, &el_size) == -1) {
+ Py_DECREF(unicode);
+ el = nullptr;
+ }
+ }
+ }
#endif
+ if (!el) {
return errors::Unimplemented("Unsupported object type ",
input_data[i]->ob_type->tp_name);
-#if PY_MAJOR_VERSION >= 3
}
-#endif
}
tflat(i) = string(el, el_size);
}
diff --git a/tensorflow/python/lib/core/py_seq_tensor.h b/tensorflow/python/lib/core/py_seq_tensor.h
index 6dc4d9c777..c6e5080c62 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.h
+++ b/tensorflow/python/lib/core/py_seq_tensor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_
-#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_
#include <Python.h>
@@ -34,4 +34,4 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_
+#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_SEQ_TENSOR_H_
diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h
index 80db840aeb..32d2868886 100644
--- a/tensorflow/python/lib/core/safe_ptr.h
+++ b/tensorflow/python/lib/core/safe_ptr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
-#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
#include <memory>
@@ -66,4 +66,4 @@ Safe_TF_StatusPtr make_safe(TF_Status* status);
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
+#endif // TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 55cae0bcbf..c9292184e6 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Gradients for operators defined in array_ops.py."""
from __future__ import absolute_import
@@ -131,8 +130,8 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
# extract the size of each input along the concat dimension
sizes = array_ops.squeeze(
array_ops.slice(
- array_ops.stack(
- sizes, axis=1), [non_neg_concat_dim, 0], [1, -1]))
+ array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
+ [1, -1]))
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
else:
offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes)
@@ -167,8 +166,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
new_values = array_ops.slice(
grad.values, begin,
array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
- out_grads.append(
- ops.IndexedSlices(new_values, grad.indices, size))
+ out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
# Lint complains begin = begin + ...
begin = math_ops.add(begin, size * mask)
else:
@@ -178,30 +176,33 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
for size in sizes:
size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
if size_concat_dim.dtype != grad.indices.dtype:
- size_concat_dim = math_ops.cast(size_concat_dim,
- dtype=grad.indices.dtype)
+ size_concat_dim = math_ops.cast(
+ size_concat_dim, dtype=grad.indices.dtype)
end = start + size_concat_dim
# Compute the 1-D Tensor of indices relevant for this input.
indices_to_select = array_ops.squeeze(
- array_ops.where(math_ops.logical_and(grad.indices >= start,
- grad.indices < end)),
+ array_ops.where(
+ math_ops.logical_and(grad.indices >= start,
+ grad.indices < end)),
squeeze_dims=[1])
new_indices = array_ops.gather(grad.indices, indices_to_select) - start
new_values = array_ops.gather(grad.values, indices_to_select)
- out_grads.append(
- ops.IndexedSlices(new_values, new_indices, size))
+ out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
start = end
else:
raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))
- return (out_grads + [None] if end_value_index <= dim_index
- else [None] + out_grads)
+ return (out_grads + [None]
+ if end_value_index <= dim_index else [None] + out_grads)
@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
return _ConcatGradHelper(
- op, grad, start_value_index=1, end_value_index=len(op.inputs),
+ op,
+ grad,
+ start_value_index=1,
+ end_value_index=len(op.inputs),
dim_index=0)
@@ -287,9 +288,13 @@ def _SplitGrad(op, *grads):
@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
returnval = array_ops.concat(list(grads), op.inputs[2])
- returnval = [returnval] + [None,] * (len(op.inputs) - 1)
+ returnval = [returnval] + [
+ None,
+ ] * (
+ len(op.inputs) - 1)
return returnval
+
ops.NotDifferentiable("Const")
@@ -334,9 +339,9 @@ def _MatrixSetDiagGrad(op, grad):
matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
min_dim = math_ops.reduce_min(matrix_shape)
diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
- grad_input = array_ops.matrix_set_diag(
- grad, array_ops.zeros(
- diag_shape, dtype=grad.dtype))
+ grad_input = array_ops.matrix_set_diag(grad,
+ array_ops.zeros(
+ diag_shape, dtype=grad.dtype))
grad_diag = array_ops.matrix_diag_part(grad)
return (grad_input, grad_diag)
@@ -444,8 +449,8 @@ def _GatherV2Grad(op, grad):
values_transpose = array_ops.transpose(values, transpose_dims)
num_segments = params_shape[axis]
- params_grad = math_ops.unsorted_segment_sum(
- values_transpose, indices, num_segments)
+ params_grad = math_ops.unsorted_segment_sum(values_transpose, indices,
+ num_segments)
# Inverts the above transpose by moving dimension 0 back to its original
# position.
@@ -536,13 +541,10 @@ def _ConjugateTransposeGrad(op, grad):
ops.NotDifferentiable("Shape")
-
ops.NotDifferentiable("ShapeN")
-
ops.NotDifferentiable("Rank")
-
ops.NotDifferentiable("Size")
@@ -590,6 +592,7 @@ def _PadGrad(op, grad):
else:
return x_grad, None
+
ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)
@@ -625,30 +628,34 @@ def _ReverseV2Grad(op, grad):
def _SpaceToBatchGrad(op, grad):
# Its gradient is the opposite op: BatchToSpace.
block_size = op.get_attr("block_size")
- return [array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size),
- None]
+ return [
+ array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
+ ]
@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
# Its gradient is the opposite op: BatchToSpaceND.
- return [array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]),
- None, None]
+ return [
+ array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
+ ]
@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
# Its gradient is the opposite op: SpaceToBatch.
block_size = op.get_attr("block_size")
- return [array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size),
- None]
+ return [
+ array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
+ ]
@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
# Its gradient is the opposite op: SpaceToBatchND.
- return [array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]),
- None, None]
+ return [
+ array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
+ ]
@ops.RegisterGradient("SpaceToDepth")
@@ -712,30 +719,28 @@ def _QuantizeAndDequantizeV3Grad(_, grad):
def _ExtractImagePatchesGrad(op, grad):
batch_size, rows_in, cols_in, channels = [
- dim.value for dim in op.inputs[0].get_shape()
+ dim.value for dim in op.inputs[0].get_shape()
]
input_bhwc = array_ops.shape(op.inputs[0])
batch_size = input_bhwc[0]
channels = input_bhwc[3]
- _, rows_out, cols_out, _ = [
- dim.value for dim in op.outputs[0].get_shape()
- ]
- _, ksize_r, ksize_c, _ = op.get_attr('ksizes')
- _, stride_r, stride_h, _ = op.get_attr('strides')
- _, rate_r, rate_c, _ = op.get_attr('rates')
- padding = op.get_attr('padding')
+ _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].get_shape()]
+ _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
+ _, stride_r, stride_h, _ = op.get_attr("strides")
+ _, rate_r, rate_c, _ = op.get_attr("rates")
+ padding = op.get_attr("padding")
ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)
- if padding == b'SAME':
+ if padding == b"SAME":
rows_out = int(ceil(rows_in / stride_r))
cols_out = int(ceil(cols_in / stride_h))
pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
pad_cols = ((cols_out - 1) * stride_h + ksize_c_eff - cols_in) // 2
- elif padding == b'VALID':
+ elif padding == b"VALID":
rows_out = int(ceil((rows_in - ksize_r_eff + 1) / stride_r))
cols_out = int(ceil((cols_in - ksize_c_eff + 1) / stride_h))
pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
@@ -744,10 +749,9 @@ def _ExtractImagePatchesGrad(op, grad):
pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)
grad_expanded = array_ops.transpose(
- array_ops.reshape(grad, (batch_size, rows_out,
- cols_out, ksize_r, ksize_c, channels)),
- (1, 2, 3, 4, 0, 5)
- )
+ array_ops.reshape(
+ grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
+ (1, 2, 3, 4, 0, 5))
grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
row_steps = range(0, rows_out * stride_r, stride_r)
@@ -759,29 +763,21 @@ def _ExtractImagePatchesGrad(op, grad):
r_low, c_low = row_steps[i] - pad_rows, col_steps[j] - pad_cols
r_high, c_high = r_low + ksize_r_eff, c_low + ksize_c_eff
- idx.extend([(r * (cols_in) + c,
- i * (cols_out * ksize_r * ksize_c) +
- j * (ksize_r * ksize_c) +
- ri * (ksize_c) + ci)
+ idx.extend([(r * (cols_in) + c, i * (cols_out * ksize_r * ksize_c) + j *
+ (ksize_r * ksize_c) + ri * (ksize_c) + ci)
for (ri, r) in enumerate(range(r_low, r_high, rate_r))
for (ci, c) in enumerate(range(c_low, c_high, rate_c))
- if 0 <= r and r < rows_in and 0 <= c and c < cols_in
- ])
+ if 0 <= r and r < rows_in and 0 <= c and c < cols_in])
- sp_shape = (rows_in * cols_in,
- rows_out * cols_out * ksize_r * ksize_c)
+ sp_shape = (rows_in * cols_in, rows_out * cols_out * ksize_r * ksize_c)
sp_mat = sparse_tensor.SparseTensor(
- array_ops.constant(idx, dtype=ops.dtypes.int64),
- array_ops.ones((len(idx),), dtype=ops.dtypes.float32),
- sp_shape
- )
+ array_ops.constant(idx, dtype=ops.dtypes.int64),
+ array_ops.ones((len(idx),), dtype=ops.dtypes.float32), sp_shape)
jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)
- grad_out = array_ops.reshape(
- jac, (rows_in, cols_in, batch_size, channels)
- )
+ grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))
return [grad_out]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 78b4a7101c..24a0c18619 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -103,16 +103,19 @@ from tensorflow.python.ops import gen_math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
# Used for slicing to specify a new 1 size dimension
newaxis = None
+tf_export("newaxis").export_constant(__name__, "newaxis")
# We override the 'slice' for the "slice" op, so we keep python's
# existing 'slice' for later use in this module.
_BaseSlice = slice
+@tf_export("identity")
def identity(input, name=None): # pylint: disable=redefined-builtin
r"""Return a tensor with the same shape and contents as input.
@@ -135,6 +138,7 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin,protected-access
+@tf_export("expand_dims")
def expand_dims(input, axis=None, name=None, dim=None):
"""Inserts a dimension of 1 into a tensor's shape.
@@ -211,6 +215,7 @@ listdiff.__doc__ = gen_array_ops._list_diff.__doc__ + "\n" + listdiff.__doc__
# pylint: disable=undefined-variable,protected-access
+@tf_export("setdiff1d")
def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
return gen_array_ops._list_diff(x, y, index_dtype, name)
@@ -220,6 +225,7 @@ setdiff1d.__doc__ = gen_array_ops._list_diff.__doc__
# pylint: enable=protected-access
+@tf_export("broadcast_dynamic_shape")
def broadcast_dynamic_shape(shape_x, shape_y):
# pylint: disable=protected-access
"""Returns the broadcasted dynamic shape between `shape_x` and `shape_y`.
@@ -235,6 +241,7 @@ def broadcast_dynamic_shape(shape_x, shape_y):
# pylint: enable=protected-access
+@tf_export("broadcast_static_shape")
def broadcast_static_shape(shape_x, shape_y):
"""Returns the broadcasted static shape between `shape_x` and `shape_y`.
@@ -251,6 +258,7 @@ def broadcast_static_shape(shape_x, shape_y):
return common_shapes.broadcast_shape(shape_x, shape_y)
+@tf_export("shape")
def shape(input, name=None, out_type=dtypes.int32):
# pylint: disable=redefined-builtin
"""Returns the shape of a tensor.
@@ -304,6 +312,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
return gen_array_ops.shape(input, name=name, out_type=out_type)
+@tf_export("shape_n")
def shape_n(input, out_type=dtypes.int32, name=None):
# pylint: disable=redefined-builtin
"""Returns shape of tensors.
@@ -330,6 +339,7 @@ def shape_n(input, out_type=dtypes.int32, name=None):
return output
+@tf_export("size")
def size(input, name=None, out_type=dtypes.int32):
# pylint: disable=redefined-builtin
"""Returns the size of a tensor.
@@ -387,6 +397,7 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
return gen_array_ops.size(input, name=name, out_type=out_type)
+@tf_export("rank")
def rank(input, name=None):
# pylint: disable=redefined-builtin
"""Returns the rank of a tensor.
@@ -577,6 +588,7 @@ def _slice_helper(tensor, slice_spec, var=None):
# pylint: disable=undefined-variable,protected-access,redefined-outer-name
+@tf_export("slice")
def slice(input_, begin, size, name=None):
# pylint: disable=redefined-builtin
"""Extracts a slice from a tensor.
@@ -629,6 +641,7 @@ def slice(input_, begin, size, name=None):
# pylint: disable=invalid-name
+@tf_export("strided_slice")
def strided_slice(input_,
begin,
end,
@@ -817,6 +830,7 @@ def _SliceHelperVar(var, slice_spec):
ops.Tensor._override_operator("__getitem__", _slice_helper)
+@tf_export("parallel_stack")
def parallel_stack(values, name="parallel_stack"):
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel.
@@ -867,6 +881,7 @@ def parallel_stack(values, name="parallel_stack"):
[expand_dims(value, 0) for value in values], shape=output_shape)
+@tf_export("stack")
def stack(values, axis=0, name="stack"):
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
@@ -1012,6 +1027,7 @@ ops.register_tensor_conversion_function((list, tuple),
_autopacking_conversion_function, 99)
+@tf_export("unstack")
def unstack(value, num=None, axis=0, name="unstack"):
"""Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
@@ -1061,6 +1077,7 @@ def unstack(value, num=None, axis=0, name="unstack"):
return gen_array_ops._unpack(value, num=num, axis=axis, name=name)
+@tf_export("concat")
def concat(values, axis, name="concat"):
"""Concatenates tensors along one dimension.
@@ -1157,6 +1174,7 @@ def concat(values, axis, name="concat"):
return gen_array_ops._concat_v2(values=values, axis=axis, name=name)
+@tf_export("boolean_mask")
def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
"""Apply boolean mask to tensor. Numpy equivalent is `tensor[mask]`.
@@ -1237,6 +1255,7 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
return _apply_mask_1d(tensor, mask, axis)
+@tf_export("sparse_mask")
def sparse_mask(a, mask_indices, name=None):
"""Masks elements of `IndexedSlices`.
@@ -1279,6 +1298,7 @@ def sparse_mask(a, mask_indices, name=None):
return ops.IndexedSlices(out_values, out_indices, a.dense_shape)
+@tf_export("unique")
def unique(x, out_idx=dtypes.int32, name=None):
# TODO(yongtang): switch to v2 once API deprecation
# period (3 weeks) pass.
@@ -1290,6 +1310,7 @@ def unique(x, out_idx=dtypes.int32, name=None):
unique.__doc__ = gen_array_ops._unique.__doc__
+@tf_export("split")
def split(value, num_or_size_splits, axis=0, num=None, name="split"):
"""Splits a tensor into sub tensors.
@@ -1356,6 +1377,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
name=name)
+@tf_export("transpose")
def transpose(a, perm=None, name="transpose", conjugate=False):
"""Transposes `a`. Permutes the dimensions according to `perm`.
@@ -1432,6 +1454,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
# pylint: disable=invalid-name
+@tf_export("matrix_transpose", "linalg.transpose")
def matrix_transpose(a, name="matrix_transpose", conjugate=False):
"""Transposes last two dimensions of tensor `a`.
@@ -1503,6 +1526,7 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
# pylint: enable=invalid-name
+@tf_export("zeros")
def zeros(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to zero.
@@ -1547,6 +1571,7 @@ def zeros(shape, dtype=dtypes.float32, name=None):
return output
+@tf_export("zeros_like")
def zeros_like(tensor, dtype=None, name=None, optimize=True):
"""Creates a tensor with all elements set to zero.
@@ -1563,9 +1588,9 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
- dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`,
- `complex64`, `complex128` or `bool`.
+ dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
+ `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
+ `complex64`, `complex128`, `bool` or `string`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.
@@ -1599,6 +1624,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
return gen_array_ops._zeros_like(tensor, name=name)
+@tf_export("ones_like")
def ones_like(tensor, dtype=None, name=None, optimize=True):
"""Creates a tensor with all elements set to 1.
@@ -1636,6 +1662,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
return ret
+@tf_export("ones")
def ones(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to 1.
@@ -1675,6 +1702,7 @@ def ones(shape, dtype=dtypes.float32, name=None):
return output
+@tf_export("placeholder")
def placeholder(dtype, shape=None, name=None):
"""Inserts a placeholder for a tensor that will be always fed.
@@ -1728,6 +1756,7 @@ def _normalize_sparse_shape(shape, name):
return (ops.convert_to_tensor(shape, dtype=dtypes.int64, name=name), rank)
+@tf_export("sparse_placeholder")
def sparse_placeholder(dtype, shape=None, name=None):
"""Inserts a placeholder for a sparse tensor that will be always fed.
@@ -1794,6 +1823,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
# pylint: enable=redefined-outer-name
+@tf_export("pad")
def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pylint: disable=invalid-name
"""Pads a tensor.
@@ -1887,6 +1917,7 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pyl
return result
+@tf_export("meshgrid")
def meshgrid(*args, **kwargs):
"""Broadcasts parameters for evaluation on an N-D grid.
@@ -2026,6 +2057,7 @@ def _TileGradShape(op):
return [tensor_shape.TensorShape(output_dims)]
+@tf_export("edit_distance")
def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
"""Computes the Levenshtein distance between sequences.
@@ -2139,6 +2171,7 @@ def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad):
narrow_range=op.get_attr("narrow_range"))
+@tf_export("required_space_to_batch_paddings")
def required_space_to_batch_paddings(input_shape,
block_shape,
base_paddings=None,
@@ -2217,6 +2250,7 @@ def required_space_to_batch_paddings(input_shape,
return result_paddings, result_crops
+@tf_export("space_to_batch")
def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=redefined-builtin
result = space_to_batch_nd(
input,
@@ -2230,6 +2264,7 @@ def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=r
space_to_batch.__doc__ = gen_array_ops._space_to_batch.__doc__
+@tf_export("space_to_depth")
def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
@@ -2237,6 +2272,7 @@ def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint:
space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
+@tf_export("depth_to_space")
def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
@@ -2244,6 +2280,7 @@ def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint:
depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__
+@tf_export("batch_to_space")
def batch_to_space(input, crops, block_size, name=None): # pylint: disable=redefined-builtin
result = batch_to_space_nd(
input,
@@ -2257,6 +2294,7 @@ def batch_to_space(input, crops, block_size, name=None): # pylint: disable=rede
batch_to_space.__doc__ = gen_array_ops._batch_to_space.__doc__
+@tf_export("one_hot")
def one_hot(indices,
depth,
on_value=None,
@@ -2416,6 +2454,7 @@ def _all_dimensions(x):
return range(0, rank(x))
+@tf_export("sequence_mask")
def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
"""Returns a mask tensor representing the first N positions of each cell.
@@ -2478,6 +2517,7 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
return gen_math_ops.cast(result, dtype)
+@tf_export("squeeze")
def squeeze(input, axis=None, name=None, squeeze_dims=None):
# pylint: disable=redefined-builtin
"""Removes dimensions of size 1 from the shape of a tensor.
@@ -2527,6 +2567,7 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None):
return gen_array_ops._squeeze(input, axis, name)
+@tf_export("where")
def where(condition, x=None, y=None, name=None):
"""Return the elements, either from `x` or `y`, depending on the `condition`.
@@ -2579,6 +2620,7 @@ def where(condition, x=None, y=None, name=None):
raise ValueError("x and y must both be non-None or both be None.")
+@tf_export("reverse")
def reverse(tensor, axis, name=None):
return gen_array_ops.reverse_v2(tensor, axis, name)
@@ -2587,6 +2629,7 @@ reverse.__doc__ = gen_array_ops.reverse_v2.__doc__
# pylint: disable=redefined-builtin
+@tf_export("reverse_sequence")
def reverse_sequence(input,
seq_lengths,
seq_axis=None,
@@ -2614,6 +2657,7 @@ reverse_sequence.__doc__ = deprecation.rewrite_argument_docstring(
"seq_dim", "seq_axis")
+@tf_export("gather")
def gather(params, indices, validate_indices=None, name=None, axis=0):
# TODO(rjryan): Remove "Gather" creation in favor of GatherV2 once the forward
# compatibility 3 week period has passed.
@@ -2629,6 +2673,7 @@ gather.__doc__ = gen_array_ops.gather_v2.__doc__
# Define quantize_v2 here in order to make name the second-to-last attribute,
# because round_mode was added later.
+@tf_export("quantize_v2")
@deprecation.deprecated(
"2017-10-25",
"`tf.quantize_v2` is deprecated, please use `tf.quantize` instead.")
@@ -2653,6 +2698,7 @@ quantize_v2.__doc__ = """Please use `tf.quantize` instead."""
# We want to expose tf.quantize instead of tf.quantize_v2; we can deprecate
# tf.quantize_v2 in next version of TensorFlow.
+@tf_export("quantize")
def quantize(input, # pylint: disable=redefined-builtin
min_range,
max_range,
diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py
index d6294c24f5..20445c78a2 100644
--- a/tensorflow/python/ops/candidate_sampling_ops.py
+++ b/tensorflow/python/ops/candidate_sampling_ops.py
@@ -23,8 +23,10 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_candidate_sampling_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('nn.uniform_candidate_sampler')
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None):
"""Samples a set of classes using a uniform base distribution.
@@ -80,6 +82,7 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
seed2=seed2, name=name)
+@tf_export('nn.log_uniform_candidate_sampler')
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None):
"""Samples a set of classes using a log-uniform (Zipfian) base distribution.
@@ -138,6 +141,7 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
seed2=seed2, name=name)
+@tf_export('nn.learned_unigram_candidate_sampler')
def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
unique, range_max, seed=None, name=None):
"""Samples a set of classes from a distribution learned during training.
@@ -194,6 +198,7 @@ def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
seed2=seed2, name=name)
+@tf_export('nn.fixed_unigram_candidate_sampler')
def fixed_unigram_candidate_sampler(true_classes,
num_true,
num_sampled,
@@ -285,6 +290,7 @@ def fixed_unigram_candidate_sampler(true_classes,
unigrams=unigrams, seed=seed1, seed2=seed2, name=name)
+@tf_export('nn.all_candidate_sampler')
def all_candidate_sampler(true_classes, num_true, num_sampled, unique,
seed=None, name=None):
"""Generate the set of all classes.
@@ -320,6 +326,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique,
name=name)
+@tf_export('nn.compute_accidental_hits')
def compute_accidental_hits(true_classes, sampled_candidates, num_true,
seed=None, name=None):
"""Compute the position ids in `sampled_candidates` matching `true_classes`.
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index eb7806ed0b..0fd6e29a49 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -57,6 +57,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
NUMERIC_TYPES = frozenset(
[dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32,
@@ -111,6 +112,7 @@ def _shape_and_dtype_str(tensor):
return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
+@tf_export('assert_proper_iterable')
def assert_proper_iterable(values):
"""Static assert that values is a "proper" iterable.
@@ -138,6 +140,7 @@ def assert_proper_iterable(values):
'Expected argument "values" to be iterable. Found: %s' % type(values))
+@tf_export('assert_negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < 0` holds element-wise.
@@ -178,6 +181,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None):
return assert_less(x, zero, data=data, summarize=summarize)
+@tf_export('assert_positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x > 0` holds element-wise.
@@ -217,6 +221,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None):
return assert_less(zero, x, data=data, summarize=summarize)
+@tf_export('assert_non_negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x >= 0` holds element-wise.
@@ -258,6 +263,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
return assert_less_equal(zero, x, data=data, summarize=summarize)
+@tf_export('assert_non_positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x <= 0` holds element-wise.
@@ -299,6 +305,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
return assert_less_equal(x, zero, data=data, summarize=summarize)
+@tf_export('assert_equal')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x == y` holds element-wise.
@@ -395,6 +402,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
+@tf_export('assert_none_equal')
def assert_none_equal(
x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x != y` holds for all elements.
@@ -445,6 +453,7 @@ def assert_none_equal(
return control_flow_ops.Assert(condition, data, summarize=summarize)
+@tf_export('assert_near')
def assert_near(
x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
name=None):
@@ -522,6 +531,7 @@ def assert_near(
return control_flow_ops.Assert(condition, data, summarize=summarize)
+@tf_export('assert_less')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < y` holds element-wise.
@@ -569,6 +579,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
+@tf_export('assert_less_equal')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x <= y` holds element-wise.
@@ -616,6 +627,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
+@tf_export('assert_greater')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x > y` holds element-wise.
@@ -663,6 +675,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
+@tf_export('assert_greater_equal')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
name=None):
"""Assert the condition `x >= y` holds element-wise.
@@ -760,6 +773,7 @@ def _assert_rank_condition(
return control_flow_ops.Assert(condition, data, summarize=summarize)
+@tf_export('assert_rank')
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank equal to `rank`.
@@ -821,6 +835,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
return assert_op
+@tf_export('assert_rank_at_least')
def assert_rank_at_least(
x, rank, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank equal to `rank` or higher.
@@ -951,6 +966,7 @@ def _assert_ranks_condition(
return control_flow_ops.Assert(condition, data, summarize=summarize)
+@tf_export('assert_rank_in')
def assert_rank_in(
x, ranks, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank in `ranks`.
@@ -1012,6 +1028,7 @@ def assert_rank_in(
return assert_op
+@tf_export('assert_integer')
def assert_integer(x, message=None, name=None):
"""Assert that `x` is of integer dtype.
@@ -1049,6 +1066,7 @@ def assert_integer(x, message=None, name=None):
return control_flow_ops.no_op('statically_determined_was_integer')
+@tf_export('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
"""Statically asserts that the given `Tensor` is of the specified type.
@@ -1096,10 +1114,12 @@ def _get_diff_for_monotonic_comparison(x):
return control_flow_ops.cond(is_shorter_than_two, short_result, diff)
+@tf_export('is_numeric_tensor')
def is_numeric_tensor(tensor):
return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES
+@tf_export('is_non_decreasing')
def is_non_decreasing(x, name=None):
"""Returns `True` if `x` is non-decreasing.
@@ -1126,6 +1146,7 @@ def is_non_decreasing(x, name=None):
return math_ops.reduce_all(math_ops.less_equal(zero, diff))
+@tf_export('is_strictly_increasing')
def is_strictly_increasing(x, name=None):
"""Returns `True` if `x` is strictly increasing.
@@ -1184,6 +1205,7 @@ def _assert_same_base_type(items, expected_type=None):
return expected_type
+@tf_export('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
"""Validate and return float type based on `tensors` and `dtype`.
@@ -1212,6 +1234,7 @@ def assert_same_float_dtype(tensors=None, dtype=None):
return dtype
+@tf_export('assert_scalar')
def assert_scalar(tensor, name=None):
with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
tensor = ops.convert_to_tensor(tensor, name=name_scope)
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 80803530c1..dd8c33247c 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -28,8 +28,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("clip_by_value")
def clip_by_value(t, clip_value_min, clip_value_max,
name=None):
"""Clips tensor values to a specified min and max.
@@ -70,6 +72,7 @@ def clip_by_value(t, clip_value_min, clip_value_max,
return t_max
+@tf_export("clip_by_norm")
def clip_by_norm(t, clip_norm, axes=None, name=None):
"""Clips tensor values to a maximum L2-norm.
@@ -117,6 +120,8 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
return tclip
+
+@tf_export("global_norm")
def global_norm(t_list, name=None):
"""Computes the global norm of multiple tensors.
@@ -164,6 +169,8 @@ def global_norm(t_list, name=None):
return norm
+
+@tf_export("clip_by_global_norm")
def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
"""Clips values of multiple tensors by the ratio of the sum of their norms.
@@ -246,6 +253,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
return list_clipped, use_norm
+@tf_export("clip_by_average_norm")
def clip_by_average_norm(t, clip_norm, name=None):
"""Clips tensor values to a maximum average L2-norm.
diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py
index 32e071db17..50690cd891 100644
--- a/tensorflow/python/ops/confusion_matrix.py
+++ b/tensorflow/python/ops/confusion_matrix.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util.tf_export import tf_export
def remove_squeezable_dimensions(
@@ -93,6 +94,7 @@ def remove_squeezable_dimensions(
return labels, predictions
+@tf_export('confusion_matrix')
def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
name=None, weights=None):
"""Computes the confusion matrix from predictions and labels.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 86941a7f2a..49191c647d 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Control Flow Operations.
See the @{$python/control_flow_ops} guide.
@@ -82,7 +81,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
-
+from tensorflow.python.util.tf_export import tf_export
# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
@@ -117,6 +116,7 @@ def _summarize_eager(tensor, summarize=None):
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
+@tf_export("Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true.
@@ -154,9 +154,10 @@ def Assert(condition, data, summarize=None, name=None):
xs = ops.convert_n_to_tensor(data)
data_str = [_summarize_eager(x, summarize) for x in xs]
raise errors.InvalidArgumentError(
- node_def=None, op=None,
- message="Expected '%s' to be true. Summarized data: %s" % (
- condition, "\n".join(data_str)))
+ node_def=None,
+ op=None,
+ message="Expected '%s' to be true. Summarized data: %s" %
+ (condition, "\n".join(data_str)))
return
with ops.name_scope(name, "Assert", [condition, data]) as name:
@@ -165,15 +166,15 @@ def Assert(condition, data, summarize=None, name=None):
# As a simple heuristic, we assume that string and int32 are
# on host to avoid the need to use cond. If it is not case,
# we will pay the price copying the tensor to host memory.
- return gen_logging_ops._assert(
- condition, data, summarize, name="Assert")
+ return gen_logging_ops._assert(condition, data, summarize, name="Assert")
else:
condition = ops.convert_to_tensor(condition, name="Condition")
+
def true_assert():
return gen_logging_ops._assert(
condition, data, summarize, name="Assert")
- guarded_assert = cond(
- condition, no_op, true_assert, name="AssertGuard")
+
+ guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
if context.in_eager_mode():
return
return guarded_assert.op
@@ -213,7 +214,7 @@ def _Identity(data, name=None):
def _NextIteration(data, name=None):
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
- if data.dtype._is_ref_dtype: # pylint: disable=protected-access
+ if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return ref_next_iteration(data, name=name)
else:
return next_iteration(data, name=name)
@@ -232,8 +233,13 @@ def _NextIteration(data, name=None):
return sparse_tensor.SparseTensor(indices, values, dense_shape)
-def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
- use_ref=True, use_input_shape=True, name=None):
+def _Enter(data,
+ frame_name,
+ is_constant=False,
+ parallel_iterations=10,
+ use_ref=True,
+ use_input_shape=True,
+ name=None):
"""Creates or finds a child frame, and makes `data` available to it.
The unique `frame_name` is used by the `Executor` to identify frames. If
@@ -255,35 +261,51 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access
- result = ref_enter(data, frame_name, is_constant, parallel_iterations,
- name=name)
+ result = ref_enter(
+ data, frame_name, is_constant, parallel_iterations, name=name)
else:
- result = enter(data, frame_name, is_constant, parallel_iterations,
- name=name)
+ result = enter(
+ data, frame_name, is_constant, parallel_iterations, name=name)
if use_input_shape:
result.set_shape(data.get_shape())
return result
else:
if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
raise TypeError("Type %s not supported" % type(data))
- values = _Enter(data.values, frame_name, is_constant,
- parallel_iterations=parallel_iterations,
- use_input_shape=use_input_shape, name=name)
- indices = enter(data.indices, frame_name, is_constant,
- parallel_iterations, name="indices")
+ values = _Enter(
+ data.values,
+ frame_name,
+ is_constant,
+ parallel_iterations=parallel_iterations,
+ use_input_shape=use_input_shape,
+ name=name)
+ indices = enter(
+ data.indices,
+ frame_name,
+ is_constant,
+ parallel_iterations,
+ name="indices")
if use_input_shape:
indices.set_shape(data.indices.get_shape())
if isinstance(data, ops.IndexedSlices):
dense_shape = data.dense_shape
if dense_shape is not None:
- dense_shape = enter(dense_shape, frame_name, is_constant,
- parallel_iterations, name="dense_shape")
+ dense_shape = enter(
+ dense_shape,
+ frame_name,
+ is_constant,
+ parallel_iterations,
+ name="dense_shape")
if use_input_shape:
dense_shape.set_shape(data.dense_shape.get_shape())
return ops.IndexedSlices(values, indices, dense_shape)
else:
- dense_shape = enter(data.dense_shape, frame_name, is_constant,
- parallel_iterations, name="dense_shape")
+ dense_shape = enter(
+ data.dense_shape,
+ frame_name,
+ is_constant,
+ parallel_iterations,
+ name="dense_shape")
if use_input_shape:
dense_shape.set_shape(data.dense_shape.get_shape())
return sparse_tensor.SparseTensor(indices, values, dense_shape)
@@ -442,8 +464,10 @@ def merge(inputs, name=None):
if any([inp is None for inp in inputs]):
raise ValueError("At least one of the merge inputs is None: %s" % inputs)
with ops.name_scope(name, "Merge", inputs) as name:
- inputs = [ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
- for inp in inputs]
+ inputs = [
+ ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
+ for inp in inputs
+ ]
if all([isinstance(v, ops.Tensor) for v in inputs]):
if all([v.dtype._is_ref_dtype for v in inputs]): # pylint: disable=protected-access
return gen_control_flow_ops._ref_merge(inputs, name)
@@ -473,6 +497,8 @@ def merge(inputs, name=None):
else:
dense_shape = None
return ops.IndexedSlices(values, indices, dense_shape), chosen_index
+
+
# pylint: enable=protected-access
@@ -486,7 +512,9 @@ def _convert_tensorarray_to_flow(tensor_or_tensor_array):
def _make_tensor_array(ta, t_or_flow):
# pylint: disable=protected-access
new_ta = tensor_array_ops.TensorArray(
- dtype=ta.dtype, handle=ta.handle, flow=t_or_flow,
+ dtype=ta.dtype,
+ handle=ta.handle,
+ flow=t_or_flow,
infer_shape=ta._infer_shape,
colocate_with_first_write_call=ta._colocate_with_first_write_call)
new_ta._colocate_with = ta._colocate_with
@@ -498,13 +526,13 @@ def _make_tensor_array(ta, t_or_flow):
def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
if len(tensors_or_tensorarrays) != len(tensors_or_flows):
raise ValueError(
- "Lengths of original Tensor list and new list do not match: %d vs. %d"
- % (len(tensors_or_tensorarrays), len(tensors_or_flows)))
+ "Lengths of original Tensor list and new list do not match: %d vs. %d" %
+ (len(tensors_or_tensorarrays), len(tensors_or_flows)))
return [
_make_tensor_array(ta, t_or_flow)
- if isinstance(ta, tensor_array_ops.TensorArray)
- else t_or_flow
- for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)]
+ if isinstance(ta, tensor_array_ops.TensorArray) else t_or_flow
+ for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
+ ]
def _ShapeLessThanOrEqual(shape1, shape2):
@@ -543,8 +571,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes):
raise ValueError(
"The shape invariant specified for %s is not compatible with "
"the initial shape of the loop variable. It enters the loop "
- "with shape %s, but the specified shape invariant is %s."
- % (inp.name, inp.get_shape(), shape))
+ "with shape %s, but the specified shape invariant is %s." %
+ (inp.name, inp.get_shape(), shape))
var.set_shape(shape)
else:
if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
@@ -555,8 +583,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes):
"The shape invariant specified for %s is not compatible with "
"the initial shape of the values tensor of this IndexedSlices. "
"It enters the loop with shape %s, but the specified shape "
- "invariant is %s."
- % (inp.values.name, inp.values.get_shape(), shape))
+ "invariant is %s." % (inp.values.name, inp.values.get_shape(),
+ shape))
var.values.set_shape(shape)
var.indices.set_shape(tensor_shape.TensorShape([shape[0]]))
if var.dense_shape is not None:
@@ -567,8 +595,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes):
"The shape invariant specified for %s is not compatible with "
"the initial shape of the shape tensor of this SparseTensor. "
"It enters the loop with shape %s, but the specified shape "
- "invariant is %s."
- % (inp.dense_shape.name, inp.dense_shape.get_shape(), shape))
+ "invariant is %s." % (inp.dense_shape.name,
+ inp.dense_shape.get_shape(), shape))
var.values.set_shape(tensor_shape.TensorShape([None]))
var.indices.set_shape(tensor_shape.TensorShape([None, shape.ndims]))
var.dense_shape.set_shape(shape)
@@ -597,8 +625,8 @@ def _EnforceShapeInvariant(merge_var, next_var):
"The shape for %s is not an invariant for the loop. It enters "
"the loop with shape %s, but has shape %s after one iteration. "
"Provide shape invariants using either the `shape_invariants` "
- "argument of tf.while_loop or set_shape() on the loop variables."
- % (merge_var.name, m_shape, n_shape))
+ "argument of tf.while_loop or set_shape() on the loop variables." %
+ (merge_var.name, m_shape, n_shape))
else:
if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
raise TypeError("Type %s not supported" % type(var))
@@ -621,9 +649,9 @@ def _EnforceShapeInvariant(merge_var, next_var):
"the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
"after one iteration. Provide shape invariants using either the "
"`shape_invariants` argument of tf.while_loop or set_shape() "
- "on the loop variables."
- % (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
- n_values_shape, n_indices_shape, n_shape_shape))
+ "on the loop variables." %
+ (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
+ n_values_shape, n_indices_shape, n_shape_shape))
else:
m_values_shape = merge_var.values.get_shape()
m_indices_shape = merge_var.indices.get_shape()
@@ -635,12 +663,12 @@ def _EnforceShapeInvariant(merge_var, next_var):
not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape) or
not _ShapeLessThanOrEqual(n_shape_shape, m_shape_shape)):
raise ValueError(
- "The shape for %s is not an invariant for the loop. It enters "
- "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
- "after one iteration. Provide shape invariants using either "
- "the `shape_invariants` argument of tf.while_loop or set_shape() "
- "on the loop variables."
- % (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
+ "The shape for %s is not an invariant for the loop. It enters "
+ "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
+ "after one iteration. Provide shape invariants using either "
+ "the `shape_invariants` argument of tf.while_loop or set_shape() "
+ "on the loop variables." %
+ (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
n_values_shape, n_indices_shape, n_shape_shape))
@@ -655,7 +683,7 @@ def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
# the types don't match.
# TODO(skyewm): call this for other cases below (needs testing)
_EnforceShapeInvariant(m, v)
- m.op._update_input(1, v) # pylint: disable=protected-access
+ m.op._update_input(1, v) # pylint: disable=protected-access
elif isinstance(m, ops.IndexedSlices):
# pylint: disable=protected-access
v = math_ops._as_indexed_slices(v, optimize=False)
@@ -718,8 +746,7 @@ def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
raise ValueError(
"Cannot create a gradient accumulator for tensor '%s' inside "
"XLA while_loop because maximum_iterations was not passed to "
- "the tf.while_loop call ('%s')."
- % (value_name, while_ctxt.name))
+ "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))
# pylint: disable=protected-access
max_iter_ctxt = max_iter.op._get_control_flow_context()
@@ -740,9 +767,9 @@ def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
"while_loop. maximum_iterations tensor '%s' for while_loop context "
"'%s' must be statically known (e.g. a constant value or known "
"shape dimension), or be defined at or outside the while loop "
- "context '%s' (currently defined in '%s')." % (
- value_name, max_iter.name, while_ctxt.name,
- curr_ctxt_name, max_iter_ctxt.name))
+ "context '%s' (currently defined in '%s')." %
+ (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
+ max_iter_ctxt.name))
max_size *= const_max_iter
# Find the next outer WhileContext (or stop if we reach the
@@ -806,9 +833,11 @@ class GradLoopState(object):
outer_forward_ctxt = forward_ctxt.outer_context
# Add the forward loop counter.
- if outer_forward_ctxt: outer_forward_ctxt.Enter()
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Enter()
cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
- if outer_forward_ctxt: outer_forward_ctxt.Exit()
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Exit()
self._forward_context = forward_ctxt
self._forward_index = forward_index
@@ -833,7 +862,8 @@ class GradLoopState(object):
real_cnt, outer_grad_state)
outer_grad_ctxt.Exit()
else:
- if outer_forward_ctxt: outer_forward_ctxt.Enter()
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Enter()
self._grad_context = WhileContext(
maximum_iterations=forward_ctxt.maximum_iterations,
parallel_iterations=forward_ctxt.parallel_iterations,
@@ -843,7 +873,8 @@ class GradLoopState(object):
grad_state=self)
self._grad_index = self._grad_context.AddBackpropLoopCounter(
cnt, outer_grad_state)
- if outer_forward_ctxt: outer_forward_ctxt.Exit()
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Exit()
@property
def outer_grad_state(self):
@@ -971,7 +1002,8 @@ class GradLoopState(object):
# curr_ctxt is the context that tf.gradients was called in.
curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
with ops.control_dependencies(None):
- if curr_ctxt: curr_ctxt.Enter()
+ if curr_ctxt:
+ curr_ctxt.Enter()
with ops.colocate_with(value):
# We only need to pass maximum_iterations to the stack if
# we're inside an XLA context.
@@ -982,11 +1014,10 @@ class GradLoopState(object):
value, self.forward_context)
# pylint: disable=protected-access
acc = gen_data_flow_ops._stack_v2(
- max_size=max_size,
- elem_type=value.dtype.base_dtype,
- name="f_acc")
+ max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
# pylint: enable=protected-access
- if curr_ctxt: curr_ctxt.Exit()
+ if curr_ctxt:
+ curr_ctxt.Exit()
# Make acc available in the forward context.
enter_acc = self.forward_context.AddValue(acc)
@@ -1007,8 +1038,7 @@ class GradLoopState(object):
else:
# value is in a cond context within the forward context.
if not isinstance(value_ctxt, CondContext):
- raise TypeError(
- "value_ctxt is not a CondContext: %s" % value_ctxt)
+ raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
if dead_branch:
# The special case for creating a zero tensor for a dead
# branch of a switch. See ControlFlowState.ZerosLike().
@@ -1132,8 +1162,8 @@ class GradLoopState(object):
if real_value is None:
# Add the stack pop op in the grad context.
- real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value,
- cur_value)
+ real_value = cur_grad_state.AddBackpropAccumulatedValue(
+ history_value, cur_value)
if cur_grad_state != self:
real_value = self._grad_context.AddValue(real_value)
self._history_map[value.name] = real_value
@@ -1152,7 +1182,7 @@ class ControlFlowState(object):
"""Maintain the mapping from the loops to their grad states."""
def __init__(self):
- self._map = {} # maps forward loop context to GradLoopState
+ self._map = {} # maps forward loop context to GradLoopState
def GetGradState(self, op, before):
"""Return the grad state for this op if it's in a forward loop context."""
@@ -1316,7 +1346,8 @@ class ControlFlowState(object):
Returns:
A zero tensor of the same shape of op.outputs[index].
"""
- if util.IsLoopSwitch(op): return None
+ if util.IsLoopSwitch(op):
+ return None
dead_branch = util.IsSwitch(op)
forward_ctxt = _GetWhileContext(op)
grad_state = self._map.get(forward_ctxt)
@@ -1359,8 +1390,8 @@ class ControlFlowState(object):
grad_state.grad_context.Enter()
# Create a zero tensor with the right shape.
- shape = grad_state.AddBackpropAccumulatedValue(
- history_zeros_shape, zeros_shape, dead_branch)
+ shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
+ zeros_shape, dead_branch)
result = array_ops.zeros(shape, val.dtype)
return result
@@ -1391,12 +1422,14 @@ class ControlFlowState(object):
else:
# Create a zeros in the outer grad context.
outer_grad_ctxt = grad_state.grad_context.outer_context
- if outer_grad_ctxt: outer_grad_ctxt.Enter()
+ if outer_grad_ctxt:
+ outer_grad_ctxt.Enter()
enter_grad_op = b_merge.op.inputs[0].op
enter_grad = enter_grad_op.inputs[0]
grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
grad_val = array_ops.zeros(grad_shape)
- if outer_grad_ctxt: outer_grad_ctxt.Exit()
+ if outer_grad_ctxt:
+ outer_grad_ctxt.Exit()
# Use the zeros for iterations > 0.
grad_state.grad_context.Enter()
next_grad_val = _NextIteration(grad_val)
@@ -1468,8 +1501,7 @@ class ControlFlowContext(object):
self._outer_context = ops.get_default_graph()._get_control_flow_context()
self._context_stack = []
if values_def:
- self._init_values_from_proto(values_def,
- import_scope=import_scope)
+ self._init_values_from_proto(values_def, import_scope=import_scope)
else:
# Values that have been already seen in this context.
self._values = set()
@@ -1530,19 +1562,16 @@ class ControlFlowContext(object):
"""
values_def = control_flow_pb2.ValuesDef()
values_def.values.extend(
- [ops.strip_name_scope(v, export_scope)
- for v in sorted(self._values)])
+ [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
for k, v in self._external_values.items():
k = ops.strip_name_scope(k, export_scope)
- values_def.external_values[k] = ops.strip_name_scope(
- v.name, export_scope)
+ values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
return values_def
@staticmethod
def _from_proto(values_def, import_scope=None):
"""Returns a `ControlFlowContext` created from `values_def`."""
- return ControlFlowContext(values_def=values_def,
- import_scope=import_scope)
+ return ControlFlowContext(values_def=values_def, import_scope=import_scope)
def AddName(self, name):
self._values.add(name)
@@ -1597,6 +1626,7 @@ class ControlFlowContext(object):
op._remove_all_control_inputs()
op._add_control_inputs(internal_control_inputs)
return internal_control_inputs
+
# pylint: enable=protected-access
def AddInnerOp(self, op):
@@ -1624,8 +1654,13 @@ class ControlFlowContext(object):
class CondContext(ControlFlowContext):
"""The context for the conditional construct."""
- def __init__(self, pred=None, pivot=None, branch=None,
- name="cond_text", context_def=None, import_scope=None):
+ def __init__(self,
+ pred=None,
+ pivot=None,
+ branch=None,
+ name="cond_text",
+ context_def=None,
+ import_scope=None):
"""Creates a `CondContext`.
Args:
@@ -1645,9 +1680,9 @@ class CondContext(ControlFlowContext):
else:
# Initializes the default fields.
ControlFlowContext.__init__(self)
- self._pred = pred # The boolean tensor for the cond predicate
- self._pivot = pivot # The predicate tensor in this branch
- self._branch = branch # 0 or 1 representing this branch
+ self._pred = pred # The boolean tensor for the cond predicate
+ self._pivot = pivot # The predicate tensor in this branch
+ self._branch = branch # 0 or 1 representing this branch
# Values considered to have been already seen in this context.
self._values.add(pred.name)
@@ -1663,15 +1698,14 @@ class CondContext(ControlFlowContext):
assert isinstance(context_def, control_flow_pb2.CondContextDef)
# Create from context_def.
g = ops.get_default_graph()
- self._name = ops.prepend_name_scope(
- context_def.context_name, import_scope)
- self._pred = g.as_graph_element(ops.prepend_name_scope(
- context_def.pred_name, import_scope))
- self._pivot = g.as_graph_element(ops.prepend_name_scope(
- context_def.pivot_name, import_scope))
+ self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
+ self._pred = g.as_graph_element(
+ ops.prepend_name_scope(context_def.pred_name, import_scope))
+ self._pivot = g.as_graph_element(
+ ops.prepend_name_scope(context_def.pivot_name, import_scope))
self._branch = context_def.branch
- super(CondContext, self).__init__(values_def=context_def.values_def,
- import_scope=import_scope)
+ super(CondContext, self).__init__(
+ values_def=context_def.values_def, import_scope=import_scope)
@property
def pred(self):
@@ -1709,18 +1743,16 @@ class CondContext(ControlFlowContext):
Returns:
A `CondContextDef` protocol buffer.
"""
- if (export_scope is None or
- self.name.startswith(export_scope)):
+ if (export_scope is None or self.name.startswith(export_scope)):
context_def = control_flow_pb2.CondContextDef()
- context_def.context_name = ops.strip_name_scope(
- self.name, export_scope)
- context_def.pred_name = ops.strip_name_scope(
- self._pred.name, export_scope)
- context_def.pivot_name = ops.strip_name_scope(
- self._pivot.name, export_scope)
+ context_def.context_name = ops.strip_name_scope(self.name, export_scope)
+ context_def.pred_name = ops.strip_name_scope(self._pred.name,
+ export_scope)
+ context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
+ export_scope)
context_def.branch = self._branch
- context_def.values_def.MergeFrom(super(CondContext, self)._to_proto(
- export_scope))
+ context_def.values_def.MergeFrom(
+ super(CondContext, self)._to_proto(export_scope))
return context_def
else:
@@ -1729,8 +1761,7 @@ class CondContext(ControlFlowContext):
@staticmethod
def from_proto(context_def, import_scope=None):
"""Returns a `CondContext` object created from `context_def`."""
- return CondContext(context_def=context_def,
- import_scope=import_scope)
+ return CondContext(context_def=context_def, import_scope=import_scope)
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
@@ -1844,8 +1875,8 @@ class CondContext(ControlFlowContext):
if original_result is None:
return no_op(), None
else:
- original_result = nest.map_structure(
- array_ops.identity, original_result)
+ original_result = nest.map_structure(array_ops.identity,
+ original_result)
if original_result is None:
return None, None
@@ -1867,12 +1898,17 @@ def _UnpackIfSingleton(res):
# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
+@tf_export("cond")
@deprecation.deprecated_args(
- None,
- "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
+ None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
"fn1", "fn2")
-def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
- fn1=None, fn2=None):
+def cond(pred,
+ true_fn=None,
+ false_fn=None,
+ strict=False,
+ name=None,
+ fn1=None,
+ fn2=None):
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
@@ -2041,6 +2077,8 @@ def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
if not strict:
merges = _UnpackIfSingleton(merges)
return merges
+
+
# pylint: enable=g-doc-args
# pylint: enable=redefined-outer-name
@@ -2136,8 +2174,7 @@ class WhileContext(ControlFlowContext):
assert isinstance(context_def, control_flow_pb2.WhileContextDef)
# Create from context_def.
g = ops.get_default_graph()
- self._name = ops.prepend_name_scope(
- context_def.context_name, import_scope)
+ self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
if context_def.maximum_iterations_name:
self._maximum_iterations = g.as_graph_element(
ops.prepend_name_scope(context_def.maximum_iterations_name,
@@ -2147,25 +2184,27 @@ class WhileContext(ControlFlowContext):
self._parallel_iterations = context_def.parallel_iterations
self._back_prop = context_def.back_prop
self._swap_memory = context_def.swap_memory
- self._pivot_for_pred = g.as_graph_element(ops.prepend_name_scope(
- context_def.pivot_for_pred_name, import_scope))
+ self._pivot_for_pred = g.as_graph_element(
+ ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
# We use this node to control constants created by the body lambda.
- self._pivot_for_body = g.as_graph_element(ops.prepend_name_scope(
- context_def.pivot_for_body_name, import_scope))
+ self._pivot_for_body = g.as_graph_element(
+ ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
# The boolean tensor for loop termination condition. Used in code
# generation for gradient computation.
self._pivot = g.as_graph_element(
ops.prepend_name_scope(context_def.pivot_name, import_scope))
# The list of exit tensors for loop variables.
- self._loop_exits = [g.as_graph_element(
- ops.prepend_name_scope(exit_name, import_scope))
- for exit_name in context_def.loop_exit_names]
+ self._loop_exits = [
+ g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
+ for exit_name in context_def.loop_exit_names
+ ]
# The list of enter tensors for loop variables.
- self._loop_enters = [g.as_graph_element(
- ops.prepend_name_scope(enter_name, import_scope))
- for enter_name in context_def.loop_enter_names]
- super(WhileContext, self).__init__(values_def=context_def.values_def,
- import_scope=import_scope)
+ self._loop_enters = [
+ g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
+ for enter_name in context_def.loop_enter_names
+ ]
+ super(WhileContext, self).__init__(
+ values_def=context_def.values_def, import_scope=import_scope)
@property
def maximum_iterations(self):
@@ -2216,11 +2255,9 @@ class WhileContext(ControlFlowContext):
Returns:
A `WhileContextDef` protocol buffer.
"""
- if (export_scope is None or
- self.name.startswith(export_scope)):
+ if (export_scope is None or self.name.startswith(export_scope)):
context_def = control_flow_pb2.WhileContextDef()
- context_def.context_name = ops.strip_name_scope(
- self.name, export_scope)
+ context_def.context_name = ops.strip_name_scope(self.name, export_scope)
context_def.parallel_iterations = self._parallel_iterations
if self._maximum_iterations is not None:
context_def.maximum_iterations_name = ops.strip_name_scope(
@@ -2231,17 +2268,16 @@ class WhileContext(ControlFlowContext):
self._pivot_for_pred.name, export_scope)
context_def.pivot_for_body_name = ops.strip_name_scope(
self._pivot_for_body.name, export_scope)
- context_def.pivot_name = ops.strip_name_scope(
- self._pivot.name, export_scope)
- context_def.loop_exit_names.extend(
- [ops.strip_name_scope(l.name, export_scope)
- for l in self._loop_exits])
- context_def.loop_enter_names.extend(
- [ops.strip_name_scope(l.name, export_scope)
- for l in self._loop_enters])
+ context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
+ export_scope)
+ context_def.loop_exit_names.extend([
+ ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
+ ])
+ context_def.loop_enter_names.extend([
+ ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
+ ])
context_def.values_def.MergeFrom(
- super(WhileContext, self)._to_proto(
- export_scope=export_scope))
+ super(WhileContext, self)._to_proto(export_scope=export_scope))
return context_def
else:
@@ -2258,8 +2294,7 @@ class WhileContext(ControlFlowContext):
Returns:
A `WhileContext` Python object.
"""
- return WhileContext(context_def=context_def,
- import_scope=import_scope)
+ return WhileContext(context_def=context_def, import_scope=import_scope)
def GetWhileContext(self):
return self
@@ -2296,8 +2331,11 @@ class WhileContext(ControlFlowContext):
result = self._outer_context.AddValue(val)
# Create an Enter to make `result` known to this loop context.
with ops.control_dependencies(None):
- enter = _Enter(result, self._name, is_constant=True,
- parallel_iterations=self._parallel_iterations)
+ enter = _Enter(
+ result,
+ self._name,
+ is_constant=True,
+ parallel_iterations=self._parallel_iterations)
enter.graph.prevent_feeding(enter)
if self._outer_context:
self._outer_context.AddInnerOp(enter.op)
@@ -2375,6 +2413,7 @@ class WhileContext(ControlFlowContext):
def _MaybeAddControlDependency(self, op):
"""Add a control input to the op if it only depends on loop invariants."""
+
def _IsOpFree(op):
"""Determines if `op` needs a control dependency."""
if op.control_inputs:
@@ -2387,6 +2426,7 @@ class WhileContext(ControlFlowContext):
if not util.IsLoopConstantEnter(x.op):
return False
return True
+
if _IsOpFree(op):
# pylint: disable=protected-access
op._add_control_input(self.GetControlPivot().op)
@@ -2420,9 +2460,12 @@ class WhileContext(ControlFlowContext):
self.Enter()
self.AddName(n.name)
- enter_n = _Enter(n, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- name="f_count")
+ enter_n = _Enter(
+ n,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_count")
self.loop_enters.append(enter_n)
merge_n = merge([enter_n, enter_n])[0]
@@ -2462,9 +2505,12 @@ class WhileContext(ControlFlowContext):
self.Enter()
self.AddName(count.name)
- enter_count = _Enter(count, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- name="b_count")
+ enter_count = _Enter(
+ count,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_count")
self.loop_enters.append(enter_count)
merge_count = merge([enter_count, enter_count])[0]
@@ -2522,9 +2568,11 @@ class WhileContext(ControlFlowContext):
# without running any iterations.
shape = grad.get_shape()
if shape.is_fully_defined():
- if self.outer_context: self.outer_context.Enter()
+ if self.outer_context:
+ self.outer_context.Enter()
acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Exit()
else:
value = op.inputs[0]
if (isinstance(self.outer_context, WhileContext) and
@@ -2543,16 +2591,21 @@ class WhileContext(ControlFlowContext):
acc = array_ops.zeros(real_shape, grad.dtype)
self.outer_context.Exit()
else:
- if self.outer_context: self.outer_context.Enter()
+ if self.outer_context:
+ self.outer_context.Enter()
zeros_shape = array_ops.shape_internal(value, optimize=False)
acc = array_ops.zeros(zeros_shape, grad.dtype)
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Exit()
self.Enter()
self.AddName(acc.name)
- enter_acc = _Enter(acc, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- name="b_acc")
+ enter_acc = _Enter(
+ acc,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_acc")
self.loop_enters.append(enter_acc)
merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
@@ -2585,14 +2638,17 @@ class WhileContext(ControlFlowContext):
dense_shape = grad.dense_shape
self.Exit()
- if self.outer_context: self.outer_context.Enter()
+ if self.outer_context:
+ self.outer_context.Enter()
if values.get_shape().is_fully_defined():
values_shape = tensor_shape.TensorShape(
[tensor_shape.Dimension(1)] + values.get_shape().dims[1:])
- if self.outer_context: self.outer_context.Enter()
- values_acc = constant_op.constant(0, values.dtype, shape=values_shape,
- name="b_acc")
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Enter()
+ values_acc = constant_op.constant(
+ 0, values.dtype, shape=values_shape, name="b_acc")
+ if self.outer_context:
+ self.outer_context.Exit()
else:
values_shape = _resource_safe_shape(op.inputs[0])[1:]
values_shape = array_ops.concat([[1], values_shape], 0)
@@ -2601,16 +2657,19 @@ class WhileContext(ControlFlowContext):
shape_acc = None
if dense_shape is not None:
if dense_shape.get_shape().is_fully_defined():
- if self.outer_context: self.outer_context.Enter()
- shape_acc = constant_op.constant(0, dense_shape.dtype,
- shape=dense_shape.get_shape())
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Enter()
+ shape_acc = constant_op.constant(
+ 0, dense_shape.dtype, shape=dense_shape.get_shape())
+ if self.outer_context:
+ self.outer_context.Exit()
else:
shape_acc = array_ops.zeros_like(
array_ops.shape_internal(op.inputs[0], optimize=False),
optimize=False)
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Exit()
self.Enter()
self.AddName(values_acc.name)
@@ -2623,9 +2682,15 @@ class WhileContext(ControlFlowContext):
# Set use_input_shape=False since the accumulator tensors will grow in
# size. If use_input_shape=True, the _update_input call below will result in
# incompatible shapes.
- enter_acc = [_Enter(x, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- use_input_shape=False, name="b_acc") for x in init_acc]
+ enter_acc = [
+ _Enter(
+ x,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ use_input_shape=False,
+ name="b_acc") for x in init_acc
+ ]
# Manually set appropriate partial shapes.
enter_acc[0].set_shape([None])
if values_acc.shape.dims is not None:
@@ -2642,8 +2707,7 @@ class WhileContext(ControlFlowContext):
]
if shape_acc is not None:
# For the shape we just keep the maximum
- acc_indexed_slices.append(
- math_ops.maximum(dense_shape, switch_acc[2][1]))
+ acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
next_acc = [_NextIteration(x) for x in acc_indexed_slices]
for xm, xn in zip(merge_acc, next_acc):
@@ -2654,7 +2718,8 @@ class WhileContext(ControlFlowContext):
self.ExitResult(exit_acc)
return ops.IndexedSlices(
- indices=exit_acc[0], values=exit_acc[1],
+ indices=exit_acc[0],
+ values=exit_acc[1],
dense_shape=exit_acc[2] if shape_acc is not None else None)
def _InitializeValues(self, values):
@@ -2687,10 +2752,14 @@ class WhileContext(ControlFlowContext):
if self._outer_context:
real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
with ops.control_dependencies(None):
- enter_vars = [_Enter(x, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- use_input_shape=(shape_invariants is None))
- for x in real_vars]
+ enter_vars = [
+ _Enter(
+ x,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ use_input_shape=(shape_invariants is None)) for x in real_vars
+ ]
for x in enter_vars:
x.graph.prevent_feeding(x)
if self._outer_context:
@@ -2751,11 +2820,13 @@ class WhileContext(ControlFlowContext):
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
summary_ref[:] = pre_summaries
with ops.control_dependencies(new_summaries):
+
def map_fn(x):
# TODO(apassos) figure out how to trigger with tensor arrays as well
if isinstance(x, tensor_array_ops.TensorArray):
return x
return array_ops.identity(x)
+
body_result = nest.map_structure(map_fn, body_result)
# Compare the structure types of input and output of body.
@@ -2812,8 +2883,7 @@ class WhileContext(ControlFlowContext):
packed_exit_vars = nest.pack_sequence_as(
structure=original_body_result,
flat_sequence=exit_vars_with_tensor_arrays)
- return (packed_exit_vars[0] if len(exit_vars) == 1
- else packed_exit_vars)
+ return (packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars)
def _FixControlInputsAndContext(self, enters):
graph = ops.get_default_graph()
@@ -2831,8 +2901,9 @@ class WhileContext(ControlFlowContext):
for x in xs:
inp_op = x.op.inputs[0].op
control_inputs = graph._control_dependencies_for_inputs([inp_op])
- outer_control_inputs = [op for op in control_inputs
- if self._IsInOuterContext(op)]
+ outer_control_inputs = [
+ op for op in control_inputs if self._IsInOuterContext(op)
+ ]
x.op._set_control_flow_context(self)
x.op._add_control_inputs(outer_control_inputs)
graph._record_op_seen_by_control_dependencies(x.op)
@@ -2843,9 +2914,16 @@ class WhileContext(ControlFlowContext):
# pylint: disable=redefined-outer-name
-def while_loop(cond, body, loop_vars, shape_invariants=None,
- parallel_iterations=10, back_prop=True, swap_memory=False,
- name=None, maximum_iterations=None):
+@tf_export("while_loop")
+def while_loop(cond,
+ body,
+ loop_vars,
+ shape_invariants=None,
+ parallel_iterations=10,
+ back_prop=True,
+ swap_memory=False,
+ name=None,
+ maximum_iterations=None):
"""Repeat `body` while the condition `cond` is true.
`cond` is a callable returning a boolean scalar tensor. `body` is a callable
@@ -3020,6 +3098,8 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
return result[1]
else:
return result
+
+
# pylint: enable=redefined-outer-name
@@ -3047,8 +3127,9 @@ def _AsTensorList(x, p):
if isinstance(v, ops.Tensor):
l.append(array_ops.identity(v))
else:
- l.append(ops.IndexedSlices(array_ops.identity(v.values),
- array_ops.identity(v.indices)))
+ l.append(
+ ops.IndexedSlices(
+ array_ops.identity(v.values), array_ops.identity(v.indices)))
return l
@@ -3058,8 +3139,7 @@ def _CheckResults(a, b):
for x, y in zip(a, b):
assert x.dtype == y.dtype, (
"Values returned by a() [%s] and b() [%s] must have "
- "the same type: %s, %s." %
- (x.name, y.name, x.dtype.name, y.dtype.name))
+ "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
def with_dependencies(dependencies, output_tensor, name=None):
@@ -3095,9 +3175,9 @@ def with_dependencies(dependencies, output_tensor, name=None):
if isinstance(output_tensor, ops.Tensor):
return _Identity(output_tensor, name=name)
else:
- return ops.IndexedSlices(_Identity(output_tensor.values, name=name),
- output_tensor.indices,
- output_tensor.dense_shape)
+ return ops.IndexedSlices(
+ _Identity(output_tensor.values, name=name), output_tensor.indices,
+ output_tensor.dense_shape)
def _GroupControlDeps(dev, deps, name=None):
@@ -3110,6 +3190,7 @@ def _GroupControlDeps(dev, deps, name=None):
# TODO(touts): Accept "inputs" as a list.
+@tf_export("group")
def group(*inputs, **kwargs):
"""Create an op that groups multiple operations.
@@ -3168,6 +3249,7 @@ def group(*inputs, **kwargs):
def device_key(dev):
"""A sort key that allows None to be compared to strings."""
return "" if dev is None else dev
+
for dev in sorted(six.iterkeys(ops_on_device), key=device_key):
deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
@@ -3175,6 +3257,7 @@ def group(*inputs, **kwargs):
return no_op(name=name)
+@tf_export("tuple")
def tuple(tensors, name=None, control_inputs=None):
"""Group tensors together.
@@ -3328,6 +3411,7 @@ def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name):
return predicates, actions
+@tf_export("case")
def case(pred_fn_pairs,
default=None,
exclusive=False,
@@ -3456,12 +3540,14 @@ class XLAControlFlowContext(ControlFlowContext):
return x
-ops.register_proto_function(ops.GraphKeys.COND_CONTEXT,
- proto_type=control_flow_pb2.CondContextDef,
- to_proto=CondContext.to_proto,
- from_proto=CondContext.from_proto)
+ops.register_proto_function(
+ ops.GraphKeys.COND_CONTEXT,
+ proto_type=control_flow_pb2.CondContextDef,
+ to_proto=CondContext.to_proto,
+ from_proto=CondContext.from_proto)
-ops.register_proto_function(ops.GraphKeys.WHILE_CONTEXT,
- proto_type=control_flow_pb2.WhileContextDef,
- to_proto=WhileContext.to_proto,
- from_proto=WhileContext.from_proto)
+ops.register_proto_function(
+ ops.GraphKeys.WHILE_CONTEXT,
+ proto_type=control_flow_pb2.WhileContextDef,
+ to_proto=WhileContext.to_proto,
+ from_proto=WhileContext.from_proto)
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index f037767cf4..83da6739db 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -25,9 +25,11 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_ctc_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
+from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access, invalid-name
+@tf_export("nn.ctc_loss")
def ctc_loss(labels, inputs, sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
@@ -185,6 +187,7 @@ def _CTCLossGrad(op, grad_loss, _):
return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
+@tf_export("nn.ctc_greedy_decoder")
def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
"""Performs greedy decoding on the logits given in input (best path).
@@ -228,6 +231,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
log_probabilities)
+@tf_export("nn.ctc_beam_search_decoder")
def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
top_paths=1, merge_repeated=True):
"""Performs beam search decoding on the logits given in input.
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index f441f6d4bf..95e45bff06 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
-
"""Data Flow Operations."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
@@ -39,6 +38,8 @@ from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
+from tensorflow.python.util.tf_export import tf_export
+
# pylint: enable=wildcard-import
@@ -53,17 +54,19 @@ def _as_type_list(dtypes):
return list(dtypes)
-def _as_shape_list(shapes, dtypes, unknown_dim_allowed=False,
+def _as_shape_list(shapes,
+ dtypes,
+ unknown_dim_allowed=False,
unknown_rank_allowed=False):
"""Convert shapes to a list of tuples of int (or None)."""
del dtypes
if unknown_dim_allowed:
- if (not isinstance(shapes, collections.Sequence)
- or not shapes
- or any(shape is None or isinstance(shape, int) for shape in shapes)):
+ if (not isinstance(shapes, collections.Sequence) or not shapes or
+ any(shape is None or isinstance(shape, int) for shape in shapes)):
raise ValueError(
"When providing partial shapes, a list of shapes must be provided.")
- if shapes is None: return None
+ if shapes is None:
+ return None
if isinstance(shapes, tensor_shape.TensorShape):
shapes = [shapes]
if not isinstance(shapes, (tuple, list)):
@@ -102,11 +105,13 @@ def _shape_common(s1, s2):
return tensor_shape.unknown_shape()
d = [
d1 if d1 is not None and d1 == d2 else None
- for (d1, d2) in zip(s1.as_list(), s2.as_list())]
+ for (d1, d2) in zip(s1.as_list(), s2.as_list())
+ ]
return tensor_shape.TensorShape(d)
# pylint: disable=protected-access
+@tf_export("QueueBase")
class QueueBase(object):
"""Base class for queue implementations.
@@ -193,8 +198,7 @@ class QueueBase(object):
TypeError: When `queues` is not a list of `QueueBase` objects,
or when the data types of `queues` are not all the same.
"""
- if ((not queues) or
- (not isinstance(queues, list)) or
+ if ((not queues) or (not isinstance(queues, list)) or
(not all(isinstance(x, QueueBase) for x in queues))):
raise TypeError("A list of queues expected")
@@ -208,12 +212,16 @@ class QueueBase(object):
queue_shapes = [q.shapes for q in queues]
reduced_shapes = [
- six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)]
+ six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)
+ ]
queue_refs = array_ops.stack([x.queue_ref for x in queues])
selected_queue = array_ops.gather(queue_refs, index)
- return QueueBase(dtypes=dtypes, shapes=reduced_shapes, names=names,
- queue_ref=selected_queue)
+ return QueueBase(
+ dtypes=dtypes,
+ shapes=reduced_shapes,
+ names=names,
+ queue_ref=selected_queue)
@property
def queue_ref(self):
@@ -280,8 +288,8 @@ class QueueBase(object):
tensors = []
for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
- tensors.append(ops.convert_to_tensor(val, dtype=dtype,
- name="component_%d" % i))
+ tensors.append(
+ ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
return tensors
@@ -553,11 +561,13 @@ class QueueBase(object):
name = "%s_Close" % self._name
if self._queue_ref.dtype == _dtypes.resource:
return gen_data_flow_ops._queue_close_v2(
- self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
+ self._queue_ref,
+ cancel_pending_enqueues=cancel_pending_enqueues,
name=name)
else:
return gen_data_flow_ops._queue_close(
- self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
+ self._queue_ref,
+ cancel_pending_enqueues=cancel_pending_enqueues,
name=name)
def is_closed(self, name=None):
@@ -575,9 +585,9 @@ class QueueBase(object):
if name is None:
name = "%s_Is_Closed" % self._name
if self._queue_ref.dtype == _dtypes.resource:
- return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref,name=name)
+ return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
else:
- return gen_data_flow_ops.queue_is_closed_(self._queue_ref,name=name)
+ return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)
def size(self, name=None):
"""Compute the number of elements in this queue.
@@ -596,6 +606,7 @@ class QueueBase(object):
return gen_data_flow_ops._queue_size(self._queue_ref, name=name)
+@tf_export("RandomShuffleQueue")
class RandomShuffleQueue(QueueBase):
"""A queue implementation that dequeues elements in a random order.
@@ -608,8 +619,14 @@ class RandomShuffleQueue(QueueBase):
@end_compatibility
"""
- def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None,
- names=None, seed=None, shared_name=None,
+ def __init__(self,
+ capacity,
+ min_after_dequeue,
+ dtypes,
+ shapes=None,
+ names=None,
+ seed=None,
+ shared_name=None,
name="random_shuffle_queue"):
"""Create a queue that dequeues elements in a random order.
@@ -667,13 +684,19 @@ class RandomShuffleQueue(QueueBase):
string = (str(seed1) + shared_name).encode("utf-8")
seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
queue_ref = gen_data_flow_ops._random_shuffle_queue_v2(
- component_types=dtypes, shapes=shapes, capacity=capacity,
- min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2,
- shared_name=shared_name, name=name)
+ component_types=dtypes,
+ shapes=shapes,
+ capacity=capacity,
+ min_after_dequeue=min_after_dequeue,
+ seed=seed1,
+ seed2=seed2,
+ shared_name=shared_name,
+ name=name)
super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
+@tf_export("FIFOQueue")
class FIFOQueue(QueueBase):
"""A queue implementation that dequeues elements in first-in first-out order.
@@ -686,8 +709,13 @@ class FIFOQueue(QueueBase):
@end_compatibility
"""
- def __init__(self, capacity, dtypes, shapes=None, names=None,
- shared_name=None, name="fifo_queue"):
+ def __init__(self,
+ capacity,
+ dtypes,
+ shapes=None,
+ names=None,
+ shared_name=None,
+ name="fifo_queue"):
"""Creates a queue that dequeues elements in a first-in first-out order.
A `FIFOQueue` has bounded capacity; supports multiple concurrent
@@ -721,12 +749,16 @@ class FIFOQueue(QueueBase):
shapes = _as_shape_list(shapes, dtypes)
names = _as_name_list(names, dtypes)
queue_ref = gen_data_flow_ops._fifo_queue_v2(
- component_types=dtypes, shapes=shapes, capacity=capacity,
- shared_name=shared_name, name=name)
+ component_types=dtypes,
+ shapes=shapes,
+ capacity=capacity,
+ shared_name=shared_name,
+ name=name)
super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
+@tf_export("PaddingFIFOQueue")
class PaddingFIFOQueue(QueueBase):
"""A FIFOQueue that supports batching variable-sized tensors by padding.
@@ -742,7 +774,12 @@ class PaddingFIFOQueue(QueueBase):
@end_compatibility
"""
- def __init__(self, capacity, dtypes, shapes, names=None, shared_name=None,
+ def __init__(self,
+ capacity,
+ dtypes,
+ shapes,
+ names=None,
+ shared_name=None,
name="padding_fifo_queue"):
"""Creates a queue that dequeues elements in a first-in first-out order.
@@ -787,16 +824,20 @@ class PaddingFIFOQueue(QueueBase):
names = _as_name_list(names, dtypes)
if len(dtypes) != len(shapes):
raise ValueError("Shapes must be provided for all components, "
- "but received %d dtypes and %d shapes."
- % (len(dtypes), len(shapes)))
+ "but received %d dtypes and %d shapes." % (len(dtypes),
+ len(shapes)))
queue_ref = gen_data_flow_ops._padding_fifo_queue_v2(
- component_types=dtypes, shapes=shapes, capacity=capacity,
- shared_name=shared_name, name=name)
+ component_types=dtypes,
+ shapes=shapes,
+ capacity=capacity,
+ shared_name=shared_name,
+ name=name)
super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
+@tf_export("PriorityQueue")
class PriorityQueue(QueueBase):
"""A queue implementation that dequeues elements in prioritized order.
@@ -809,7 +850,12 @@ class PriorityQueue(QueueBase):
@end_compatibility
"""
- def __init__(self, capacity, types, shapes=None, names=None, shared_name=None,
+ def __init__(self,
+ capacity,
+ types,
+ shapes=None,
+ names=None,
+ shared_name=None,
name="priority_queue"):
"""Creates a queue that dequeues elements in a first-in first-out order.
@@ -850,14 +896,17 @@ class PriorityQueue(QueueBase):
shapes = _as_shape_list(shapes, types)
queue_ref = gen_data_flow_ops._priority_queue_v2(
- component_types=types, shapes=shapes, capacity=capacity,
- shared_name=shared_name, name=name)
+ component_types=types,
+ shapes=shapes,
+ capacity=capacity,
+ shared_name=shared_name,
+ name=name)
priority_dtypes = [_dtypes.int64] + types
priority_shapes = [()] + shapes if shapes else shapes
- super(PriorityQueue, self).__init__(
- priority_dtypes, priority_shapes, names, queue_ref)
+ super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
+ queue_ref)
# TODO(josh11b): class BatchQueue(QueueBase):
@@ -937,8 +986,10 @@ class Barrier(object):
self._shapes = [tensor_shape.unknown_shape() for _ in self._types]
self._barrier_ref = gen_data_flow_ops._barrier(
- component_types=self._types, shapes=self._shapes,
- shared_name=shared_name, name=name)
+ component_types=self._types,
+ shapes=self._shapes,
+ shared_name=shared_name,
+ name=name)
if context.in_graph_mode():
self._name = self._barrier_ref.op.name.split("/")[-1]
else:
@@ -1022,12 +1073,13 @@ class Barrier(object):
"""
if name is None:
name = "%s_BarrierTakeMany" % self._name
- ret = gen_data_flow_ops._barrier_take_many(self._barrier_ref,
- num_elements,
- self._types,
- allow_small_batch,
- timeout,
- name=name)
+ ret = gen_data_flow_ops._barrier_take_many(
+ self._barrier_ref,
+ num_elements,
+ self._types,
+ allow_small_batch,
+ timeout,
+ name=name)
# NOTE(mrry): Not using a shape function because we need access to
# the Barrier object.
@@ -1042,8 +1094,7 @@ class Barrier(object):
op.outputs[1].set_shape(tensor_shape.vector(batch_dim)) # keys
for output, shape in zip(op.outputs[2:], self._shapes): # value_list
output.set_shape(
- tensor_shape.TensorShape([batch_dim]).concatenate(
- shape))
+ tensor_shape.TensorShape([batch_dim]).concatenate(shape))
return ret
@@ -1106,6 +1157,7 @@ class Barrier(object):
self._barrier_ref, name=name)
+@tf_export("ConditionalAccumulatorBase")
class ConditionalAccumulatorBase(object):
"""A conditional accumulator for aggregating gradients.
@@ -1184,6 +1236,7 @@ class ConditionalAccumulatorBase(object):
name=name)
+@tf_export("ConditionalAccumulator")
class ConditionalAccumulator(ConditionalAccumulatorBase):
"""A conditional accumulator for aggregating gradients.
@@ -1263,6 +1316,7 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
return out
+@tf_export("SparseConditionalAccumulator")
class SparseConditionalAccumulator(ConditionalAccumulatorBase):
"""A conditional accumulator for aggregating sparse gradients.
@@ -1289,8 +1343,8 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
name="sparse_conditional_accumulator"):
accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
dtype=dtype, shape=shape, shared_name=shared_name, name=name)
- super(SparseConditionalAccumulator,
- self).__init__(dtype, shape, accumulator_ref)
+ super(SparseConditionalAccumulator, self).__init__(dtype, shape,
+ accumulator_ref)
def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
"""Attempts to apply a gradient to the accumulator.
@@ -1359,8 +1413,8 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
local_step=local_step,
gradient_indices=math_ops.to_int64(grad_indices),
gradient_values=grad_values,
- gradient_shape=math_ops.to_int64([] if grad_shape is None else
- grad_shape),
+ gradient_shape=math_ops.to_int64([]
+ if grad_shape is None else grad_shape),
has_known_shape=(grad_shape is not None),
name=name)
@@ -1422,11 +1476,16 @@ class BaseStagingArea(object):
_identifier = 0
_lock = threading.Lock()
- def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
- capacity=0, memory_limit=0):
+ def __init__(self,
+ dtypes,
+ shapes=None,
+ names=None,
+ shared_name=None,
+ capacity=0,
+ memory_limit=0):
if shared_name is None:
- self._name = (ops.get_default_graph()
- .unique_name(self.__class__.__name__))
+ self._name = (
+ ops.get_default_graph().unique_name(self.__class__.__name__))
elif isinstance(shared_name, six.string_types):
self._name = shared_name
else:
@@ -1523,8 +1582,9 @@ class BaseStagingArea(object):
(sorted(vals.keys()), sorted(self._names)))
# The order of values in `self._names` indicates the order in which the
# tensors in the dictionary `vals` must be listed.
- vals, indices, n = zip(*[(vals[k], i, k) for i, k in enumerate(self._names)
- if k in vals])
+ vals, indices, n = zip(*[(vals[k], i, k)
+ for i, k in enumerate(self._names)
+ if k in vals])
else:
if self._names:
raise ValueError("You must enqueue a dictionary in a staging area "
@@ -1532,7 +1592,7 @@ class BaseStagingArea(object):
if indices is None:
raise ValueError("Indices must be supplied when inserting a list "
- "of tensors")
+ "of tensors")
if len(indices) != len(vals):
raise ValueError("Number of indices '%s' doesn't match "
@@ -1544,8 +1604,8 @@ class BaseStagingArea(object):
# Sanity check number of values
if not len(vals) <= len(self._dtypes):
- raise ValueError("Unexpected number of inputs '%s' vs '%s'" % (
- len(vals), len(self._dtypes)))
+ raise ValueError("Unexpected number of inputs '%s' vs '%s'" %
+ (len(vals), len(self._dtypes)))
tensors = []
@@ -1553,14 +1613,14 @@ class BaseStagingArea(object):
dtype, shape = self._dtypes[i], self._shapes[i]
# Check dtype
if not val.dtype == dtype:
- raise ValueError("Datatypes do not match. '%s' != '%s'" %(
- str(val.dtype), str(dtype)))
+ raise ValueError("Datatypes do not match. '%s' != '%s'" %
+ (str(val.dtype), str(dtype)))
# Check shape
val.get_shape().assert_is_compatible_with(shape)
- tensors.append(ops.convert_to_tensor(val, dtype=dtype,
- name="component_%d" % i))
+ tensors.append(
+ ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
return tensors, indices
@@ -1623,6 +1683,7 @@ class BaseStagingArea(object):
else:
return [vals]
+
class StagingArea(BaseStagingArea):
"""Class for staging inputs. No ordering guarantees.
@@ -1657,8 +1718,13 @@ class StagingArea(BaseStagingArea):
"""
- def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
- capacity=0, memory_limit=0):
+ def __init__(self,
+ dtypes,
+ shapes=None,
+ names=None,
+ shared_name=None,
+ capacity=0,
+ memory_limit=0):
"""Constructs a staging area object.
The two optional lists, `shapes` and `names`, must be of the same length
@@ -1693,9 +1759,8 @@ class StagingArea(BaseStagingArea):
ValueError: If one of the arguments is invalid.
"""
- super(StagingArea, self).__init__(dtypes, shapes,
- names, shared_name,
- capacity, memory_limit)
+ super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
+ capacity, memory_limit)
def put(self, values, name=None):
"""Create an op that places a value into the staging area.
@@ -1717,14 +1782,18 @@ class StagingArea(BaseStagingArea):
self._scope_vals(values)) as scope:
# Hard-code indices for this staging area
- indices = (list(six.moves.range(len(values)))
- if isinstance(values, (list, tuple)) else None)
+ indices = (
+ list(six.moves.range(len(values)))
+ if isinstance(values, (list, tuple)) else None)
vals, _ = self._check_put_dtypes(values, indices)
with ops.colocate_with(self._coloc_op):
- op = gen_data_flow_ops.stage(values=vals, shared_name=self._name,
- name=scope, capacity=self._capacity,
- memory_limit=self._memory_limit)
+ op = gen_data_flow_ops.stage(
+ values=vals,
+ shared_name=self._name,
+ name=scope,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
return op
@@ -1732,7 +1801,7 @@ class StagingArea(BaseStagingArea):
with ops.colocate_with(self._coloc_op):
ret = get_fn()
- indices = list(six.moves.range(len(self._dtypes))) # Hard coded
+ indices = list(six.moves.range(len(self._dtypes))) # Hard coded
return self._get_return_value(ret, indices)
def get(self, name=None):
@@ -1760,10 +1829,12 @@ class StagingArea(BaseStagingArea):
if name is None:
name = "%s_get" % self._name
+ # pylint: disable=bad-continuation
fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
shared_name=self._name, name=name,
capacity=self._capacity,
memory_limit=self._memory_limit)
+ # pylint: enable=bad-continuation
return self.__internal_get(fn, name)
@@ -1788,10 +1859,12 @@ class StagingArea(BaseStagingArea):
if name is None:
name = "%s_peek" % self._name
+ # pylint: disable=bad-continuation
fn = lambda: gen_data_flow_ops.stage_peek(index,
dtypes=self._dtypes, shared_name=self._name,
name=name, capacity=self._capacity,
memory_limit=self._memory_limit)
+ # pylint: enable=bad-continuation
return self.__internal_get(fn, name)
@@ -1807,9 +1880,12 @@ class StagingArea(BaseStagingArea):
if name is None:
name = "%s_size" % self._name
- return gen_data_flow_ops.stage_size(name=name, shared_name=self._name,
- dtypes=self._dtypes, capacity=self._capacity,
- memory_limit=self._memory_limit)
+ return gen_data_flow_ops.stage_size(
+ name=name,
+ shared_name=self._name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
def clear(self, name=None):
"""Clears the staging area.
@@ -1823,14 +1899,16 @@ class StagingArea(BaseStagingArea):
if name is None:
name = "%s_clear" % self._name
- return gen_data_flow_ops.stage_clear(name=name, shared_name=self._name,
- dtypes=self._dtypes, capacity=self._capacity,
- memory_limit=self._memory_limit)
+ return gen_data_flow_ops.stage_clear(
+ name=name,
+ shared_name=self._name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
class MapStagingArea(BaseStagingArea):
- """
- A `MapStagingArea` is a TensorFlow data structure that stores tensors across
- multiple steps, and exposes operations that can put and get tensors.
+ """A `MapStagingArea` is a TensorFlow data structure that stores tensors across multiple steps, and exposes operations that can put and get tensors.
Each `MapStagingArea` element is a (key, value) pair.
Only int64 keys are supported, other types should be
@@ -1843,7 +1921,8 @@ class MapStagingArea(BaseStagingArea):
It supports multiple concurrent producers and consumers; and
provides exactly-once delivery.
- Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors whose
+ Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
+ whose
dtypes are described by `dtypes`, and whose shapes are optionally described
by the `shapes` argument.
@@ -1887,10 +1966,16 @@ class MapStagingArea(BaseStagingArea):
associated with it are removed.
"""
- def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
- ordered=False, capacity=0, memory_limit=0):
- """
- Args:
+ def __init__(self,
+ dtypes,
+ shapes=None,
+ names=None,
+ shared_name=None,
+ ordered=False,
+ capacity=0,
+ memory_limit=0):
+ """Args:
+
dtypes: A list of types. The length of dtypes must equal the number
of tensors in each element.
capacity: (Optional.) Maximum number of elements.
@@ -1916,9 +2001,8 @@ class MapStagingArea(BaseStagingArea):
"""
- super(MapStagingArea, self).__init__(dtypes, shapes,
- names, shared_name,
- capacity, memory_limit)
+ super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
+ capacity, memory_limit)
# Defer to different methods depending if the map is ordered
self._ordered = ordered
@@ -1941,8 +2025,7 @@ class MapStagingArea(BaseStagingArea):
self._clear_fn = gen_data_flow_ops.map_clear
def put(self, key, vals, indices=None, name=None):
- """
- Create an op that stores the (key, vals) pair in the staging area.
+ """Create an op that stores the (key, vals) pair in the staging area.
Incomplete puts are possible, preferably using a dictionary for vals
as the appropriate dtypes and shapes can be inferred from the value names
@@ -1964,7 +2047,8 @@ class MapStagingArea(BaseStagingArea):
The created op
Raises:
- ValueError: If the number or type of inputs don't match the staging area.
+ ValueError: If the number or type of inputs don't match the staging
+ area.
"""
with ops.name_scope(name, "%s_put" % self._name,
@@ -1973,10 +2057,15 @@ class MapStagingArea(BaseStagingArea):
vals, indices = self._check_put_dtypes(vals, indices)
with ops.colocate_with(self._coloc_op):
- op = self._put_fn(key, indices, vals, dtypes=self._dtypes,
- shared_name=self._name, name=scope,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ op = self._put_fn(
+ key,
+ indices,
+ vals,
+ dtypes=self._dtypes,
+ shared_name=self._name,
+ name=scope,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
return op
def _get_indices_and_dtypes(self, indices=None):
@@ -1992,13 +2081,13 @@ class MapStagingArea(BaseStagingArea):
if all(isinstance(i, str) for i in indices):
if self._names is None:
raise ValueError("String indices provided '%s', but this Staging Area "
- "was not created with names." % indices)
+ "was not created with names." % indices)
try:
indices = [self._names.index(n) for n in indices]
except ValueError:
raise ValueError("Named index '%s' not in "
- "Staging Area names '%s'" % (n, self._names))
+ "Staging Area names '%s'" % (n, self._names))
elif all(isinstance(i, int) for i in indices):
pass
else:
@@ -2009,10 +2098,8 @@ class MapStagingArea(BaseStagingArea):
return indices, dtypes
-
def peek(self, key, indices=None, name=None):
- """
- Peeks at staging area data associated with the key.
+ """Peeks at staging area data associated with the key.
If the key is not in the staging area, it will block
until the associated (key, value) is inserted.
@@ -2035,22 +2122,22 @@ class MapStagingArea(BaseStagingArea):
indices, dtypes = self._get_indices_and_dtypes(indices)
with ops.colocate_with(self._coloc_op):
- result = self._peek_fn(key, shared_name=self._name,
- indices=indices,
- dtypes=dtypes,
- name=name,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ result = self._peek_fn(
+ key,
+ shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
return self._get_return_value(result, indices)
def get(self, key=None, indices=None, name=None):
- """
- If the key is provided, the associated (key, value)
- is returned from the staging area. If the key is not
- in the staging area, this method will block until
- the associated (key, value) is inserted.
+ """If the key is provided, the associated (key, value) is returned from the staging area.
+ If the key is not in the staging area, this method will block until
+ the associated (key, value) is inserted.
If no key is provided and the staging area is ordered,
the (key, value) with the smallest key will be returned.
Otherwise, a random (key, value) will be returned.
@@ -2075,12 +2162,10 @@ class MapStagingArea(BaseStagingArea):
return self._pop(key, indices=indices, name=name)
def _pop(self, key, indices=None, name=None):
- """
- Remove and return the associated (key, value)
- is returned from the staging area. If the key is not
- in the staging area, this method will block until
- the associated (key, value) is inserted.
+ """Remove and return the associated (key, value) is returned from the staging area.
+ If the key is not in the staging area, this method will block until
+ the associated (key, value) is inserted.
Args:
key: Key associated with the required data
indices: Partial list of tensors to retrieve (optional).
@@ -2098,21 +2183,21 @@ class MapStagingArea(BaseStagingArea):
indices, dtypes = self._get_indices_and_dtypes(indices)
with ops.colocate_with(self._coloc_op):
- result = self._pop_fn(key, shared_name=self._name,
- indices=indices,
- dtypes=dtypes,
- name=name,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ result = self._pop_fn(
+ key,
+ shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
return key, self._get_return_value(result, indices)
def _popitem(self, indices=None, name=None):
- """
- If the staging area is ordered,
- the (key, value) with the smallest key will be returned.
- Otherwise, a random (key, value) will be returned.
+ """If the staging area is ordered, the (key, value) with the smallest key will be returned.
+ Otherwise, a random (key, value) will be returned.
If the staging area is empty when this operation executes,
it will block until there is an element to dequeue.
@@ -2133,12 +2218,13 @@ class MapStagingArea(BaseStagingArea):
indices, dtypes = self._get_indices_and_dtypes(indices)
with ops.colocate_with(self._coloc_op):
- key, result = self._popitem_fn(shared_name=self._name,
- indices=indices,
- dtypes=dtypes,
- name=name,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ key, result = self._popitem_fn(
+ shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
# Separate keys and results out from
# underlying namedtuple
@@ -2148,8 +2234,7 @@ class MapStagingArea(BaseStagingArea):
return key, result
def size(self, name=None):
- """
- Returns the number of elements in the staging area.
+ """Returns the number of elements in the staging area.
Args:
name: A name for the operation (optional)
@@ -2160,14 +2245,15 @@ class MapStagingArea(BaseStagingArea):
if name is None:
name = "%s_size" % self._name
- return self._size_fn(shared_name=self._name,
- name=name, dtypes=self._dtypes,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ return self._size_fn(
+ shared_name=self._name,
+ name=name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
def incomplete_size(self, name=None):
- """
- Returns the number of incomplete elements in the staging area.
+ """Returns the number of incomplete elements in the staging area.
Args:
name: A name for the operation (optional)
@@ -2178,16 +2264,15 @@ class MapStagingArea(BaseStagingArea):
if name is None:
name = "%s_incomplete_size" % self._name
- return self._incomplete_size_fn(shared_name=self._name,
- name=name, dtypes=self._dtypes,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
-
-
+ return self._incomplete_size_fn(
+ shared_name=self._name,
+ name=name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
def clear(self, name=None):
- """
- Clears the staging area.
+ """Clears the staging area.
Args:
name: A name for the operation (optional)
@@ -2198,10 +2283,12 @@ class MapStagingArea(BaseStagingArea):
if name is None:
name = "%s_clear" % self._name
- return self._clear_fn(shared_name=self._name,
- name=name, dtypes=self._dtypes,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ return self._clear_fn(
+ shared_name=self._name,
+ name=name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
class RecordInput(object):
diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py
index b6b20d1b4a..1f300b7147 100644
--- a/tensorflow/python/ops/distributions/bernoulli.py
+++ b/tensorflow/python/ops/distributions/bernoulli.py
@@ -29,8 +29,10 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("distributions.Bernoulli")
class Bernoulli(distribution.Distribution):
"""Bernoulli distribution.
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 2b93478cdf..6d6b40b045 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -45,6 +46,7 @@ _beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1].` It must have a shape compatible with `self.batch_shape()`."""
+@tf_export("distributions.Beta")
class Beta(distribution.Distribution):
"""Beta distribution.
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index 8f6d18d91a..44d64070ce 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -111,6 +112,7 @@ class _Mapping(collections.namedtuple(
@six.add_metaclass(abc.ABCMeta)
+@tf_export("distributions.bijectors.Bijector")
class Bijector(object):
"""Interface for transformations of a `Distribution` sample.
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index 84ca6db4c4..9161e3fa9f 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
@@ -58,6 +59,7 @@ def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
return event, params
+@tf_export("distributions.Categorical")
class Categorical(distribution.Distribution):
"""Categorical distribution.
@@ -263,7 +265,9 @@ class Categorical(distribution.Distribution):
logits_2d = self.logits
else:
logits_2d = array_ops.reshape(self.logits, [-1, self.event_size])
- draws = random_ops.multinomial(logits_2d, n, seed=seed)
+ sample_dtype = dtypes.int64 if self.dtype.size > 4 else dtypes.int32
+ draws = random_ops.multinomial(
+ logits_2d, n, seed=seed, output_dtype=sample_dtype)
draws = array_ops.reshape(
array_ops.transpose(draws),
array_ops.concat([[n], self.batch_shape_tensor()], 0))
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 2accedf1b9..25afeec936 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -42,6 +43,7 @@ dtype `self.dtype` and be in the `(self.event_shape() - 1)`-simplex, i.e.,
`self.batch_shape() + self.event_shape()`."""
+@tf_export("distributions.Dirichlet")
class Dirichlet(distribution.Distribution):
"""Dirichlet distribution.
diff --git a/tensorflow/python/ops/distributions/dirichlet_multinomial.py b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
index aa2b511c54..03a98c56ba 100644
--- a/tensorflow/python/ops/distributions/dirichlet_multinomial.py
+++ b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -49,6 +50,7 @@ fractional components, and such that
with `self.concentration` and `self.total_count`."""
+@tf_export("distributions.DirichletMultinomial")
class DirichletMultinomial(distribution.Distribution):
"""Dirichlet-Multinomial compound distribution.
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 2d4c3509bc..4071e50e81 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -197,6 +198,7 @@ class _DistributionMeta(abc.ABCMeta):
return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
+@tf_export("distributions.ReparameterizationType")
class ReparameterizationType(object):
"""Instances of this class represent how sampling is reparameterized.
@@ -239,15 +241,20 @@ class ReparameterizationType(object):
# reparameterized distribution support straight-through gradients with
# respect to all parameters.
FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
+tf_export("distributions.FULLY_REPARAMETERIZED").export_constant(
+ __name__, "FULLY_REPARAMETERIZED")
# Not reparameterized distribution: samples from a non-
# reparameterized distribution do not support straight-through gradients for
# at least some of the parameters.
NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
+tf_export("distributions.NOT_REPARAMETERIZED").export_constant(
+ __name__, "NOT_REPARAMETERIZED")
@six.add_metaclass(_DistributionMeta)
+@tf_export("distributions.Distribution")
class Distribution(_BaseDistribution):
"""A generic probability distribution base class.
@@ -1075,7 +1082,7 @@ class Distribution(_BaseDistribution):
Denote this distribution (`self`) by `p` and the `other` distribution by
`q`. Assuming `p, q` are absolutely continuous with respect to reference
- measure `r`, (Shanon) cross entropy is defined as:
+ measure `r`, the KL divergence is defined as:
```none
KL[p, q] = E_p[log(p(X)/q(X))]
diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py
index 281641b915..6345a76d48 100644
--- a/tensorflow/python/ops/distributions/exponential.py
+++ b/tensorflow/python/ops/distributions/exponential.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -35,6 +36,7 @@ __all__ = [
]
+@tf_export("distributions.Exponential")
class Exponential(gamma.Gamma):
"""Exponential distribution.
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index 4ac2b9b4ef..8fb218be3a 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -41,6 +42,7 @@ __all__ = [
]
+@tf_export("distributions.Gamma")
class Gamma(distribution.Distribution):
"""Gamma distribution.
diff --git a/tensorflow/python/ops/distributions/identity_bijector.py b/tensorflow/python/ops/distributions/identity_bijector.py
index f277eda8bb..2972c3554b 100644
--- a/tensorflow/python/ops/distributions/identity_bijector.py
+++ b/tensorflow/python/ops/distributions/identity_bijector.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -27,6 +28,7 @@ __all__ = [
]
+@tf_export("distributions.bijectors.Identity")
class Identity(bijector.Bijector):
"""Compute Y = g(X) = X.
diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py
index 829b9611cf..e3c6f3e789 100644
--- a/tensorflow/python/ops/distributions/kullback_leibler.py
+++ b/tensorflow/python/ops/distributions/kullback_leibler.py
@@ -23,6 +23,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.tf_export import tf_export
_DIVERGENCES = {}
@@ -50,6 +51,7 @@ def _registered_kl(type_a, type_b):
return kl_fn
+@tf_export("distributions.kl_divergence")
def kl_divergence(distribution_a, distribution_b,
allow_nan_stats=True, name=None):
"""Get the KL-divergence KL(distribution_a || distribution_b).
@@ -142,6 +144,7 @@ def cross_entropy(ref, other,
ref, other, allow_nan_stats=allow_nan_stats)
+@tf_export("distributions.RegisterKL")
class RegisterKL(object):
"""Decorator to register a KL divergence implementation function.
diff --git a/tensorflow/python/ops/distributions/laplace.py b/tensorflow/python/ops/distributions/laplace.py
index 5c964ff78a..e98ac855c5 100644
--- a/tensorflow/python/ops/distributions/laplace.py
+++ b/tensorflow/python/ops/distributions/laplace.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -41,6 +42,7 @@ __all__ = [
]
+@tf_export("distributions.Laplace")
class Laplace(distribution.Distribution):
"""The Laplace distribution with location `loc` and `scale` parameters.
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index 04762565c2..26b5c5aef9 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -50,6 +51,7 @@ fractional components, and such that
with `self.probs` and `self.total_count`."""
+@tf_export("distributions.Multinomial")
class Multinomial(distribution.Distribution):
"""Multinomial distribution.
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index 0ef1c91df8..e7f120ea2d 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -40,6 +41,7 @@ __all__ = [
]
+@tf_export("distributions.Normal")
class Normal(distribution.Distribution):
"""The Normal distribution with location `loc` and `scale` parameters.
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index 073ac4286b..778fefb8c2 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -41,6 +42,7 @@ __all__ = [
]
+@tf_export("distributions.StudentT")
class StudentT(distribution.Distribution):
"""Student's t-distribution.
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
index 9b555f87ea..3580af18f2 100644
--- a/tensorflow/python/ops/distributions/uniform.py
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -29,8 +29,10 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("distributions.Uniform")
class Uniform(distribution.Distribution):
"""Uniform distribution with `low` and `high` parameters.
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index f4561d1a83..3826585f59 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
def _gather(params, ids, name=None):
@@ -257,6 +258,7 @@ def _embedding_lookup_and_transform(params,
return ret
+@tf_export("nn.embedding_lookup")
def embedding_lookup(
params,
ids,
@@ -325,6 +327,7 @@ def embedding_lookup(
transform_fn=None)
+@tf_export("nn.embedding_lookup_sparse")
def embedding_lookup_sparse(params,
sp_ids,
sp_weights,
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 688512bea6..ac03d30fcd 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -44,9 +44,11 @@ from tensorflow.python.ops.gen_functional_ops import *
from tensorflow.python.ops.gen_functional_ops import _symbolic_gradient
# pylint: enable=unused-import
from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
+@tf_export("foldl")
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
swap_memory=False, name=None):
"""foldl on the list of tensors unpacked from `elems` on dimension 0.
@@ -134,6 +136,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
return r_a
+@tf_export("foldr")
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
swap_memory=False, name=None):
"""foldr on the list of tensors unpacked from `elems` on dimension 0.
@@ -221,6 +224,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
return r_a
+@tf_export("map_fn")
def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
swap_memory=False, infer_shape=True, name=None):
"""map on the list of tensors unpacked from `elems` on dimension 0.
@@ -424,6 +428,7 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
return output_pack(results_flat)
+@tf_export("scan")
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
swap_memory=False, infer_shape=True, name=None):
"""scan on the list of tensors unpacked from `elems` on dimension 0.
@@ -453,7 +458,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
`[i1, i2]` then an appropriate signature for `fn` in `python2` is:
- `fn = lambda (acc_p1, acc_p2), (t1 [t2, t3]):` and `fn` must return a list,
+ `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
`[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
one that works in `python3`, is:
`fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
diff --git a/tensorflow/python/ops/gradient_checker.py b/tensorflow/python/ops/gradient_checker.py
index 193046ba70..12afcd0b51 100644
--- a/tensorflow/python/ops/gradient_checker.py
+++ b/tensorflow/python/ops/gradient_checker.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
def _product(t):
@@ -264,6 +265,7 @@ def _compute_gradient_list(x,
return ret
+@tf_export("test.compute_gradient")
def compute_gradient(x,
x_shape,
y,
@@ -325,6 +327,7 @@ def compute_gradient(x,
return ret
+@tf_export("test.compute_gradient_error")
def compute_gradient_error(x,
x_shape,
y,
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 20c7a9fd66..314726ede6 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -50,7 +50,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import tf_logging as logging
-
+from tensorflow.python.util.tf_export import tf_export
# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
@@ -234,9 +234,10 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
raise TypeError(
"Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
y.dtype)
- new_grad_ys.append(array_ops.fill(
- array_ops.shape(y), constant_op.constant(
- 1, dtype=y.dtype, name="grad_ys_%d" % i)))
+ new_grad_ys.append(
+ array_ops.fill(
+ array_ops.shape(y),
+ constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
continue
if y.dtype.is_floating or y.dtype.is_integer:
if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
@@ -394,6 +395,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
return grad_fn()
+@tf_export("gradients")
def gradients(ys,
xs,
grad_ys=None,
@@ -490,11 +492,12 @@ def gradients(ys,
name, "gradients",
list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
- xs = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
- else x
- for x in xs]
- xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name="x",
- as_ref=True)
+ xs = [
+ x.handle if isinstance(x, resource_variable_ops.ResourceVariable) else x
+ for x in xs
+ ]
+ xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
+ xs, name="x", as_ref=True)
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
# The approach we take here is as follows: Create a list of all ops in the
@@ -511,9 +514,8 @@ def gradients(ys,
to_ops = [t.op for t in ys]
from_ops = [t.op for t in xs]
stop_gradient_ops = [t.op for t in stop_gradients]
- pending_count, loop_state = _PendingCount(ops.get_default_graph(), to_ops,
- from_ops,
- colocate_gradients_with_ops)
+ pending_count, loop_state = _PendingCount(
+ ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
# Iterate over the collected ops.
#
@@ -586,9 +588,8 @@ def gradients(ys,
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
- if (not isinstance(out_grad, ops.Tensor) and
- not out_grad) and ((not grad_fn and is_func_call) or
- _IsTrainable(op.outputs[i])):
+ if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
+ (not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])):
# Only trainable outputs or outputs for a function call that
# will use SymbolicGradient get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
@@ -605,17 +606,17 @@ def gradients(ys,
if grad_fn:
# If grad_fn was found, do not use SymbolicGradient even for
# functions.
- in_grads = _MaybeCompile(
- grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
+ in_grads = _MaybeCompile(grad_scope, op, func_call,
+ lambda: grad_fn(op, *out_grads))
else:
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
- in_grads = _MaybeCompile(
- grad_scope, op, func_call, lambda: _SymGrad(op, out_grads))
+ in_grads = _MaybeCompile(grad_scope, op, func_call,
+ lambda: _SymGrad(op, out_grads))
in_grads = _AsList(in_grads)
_VerifyGeneratedGradients(in_grads, op)
- if gate_gradients and len(
- [x for x in in_grads if x is not None]) > 1:
+ if gate_gradients and len([x for x in in_grads
+ if x is not None]) > 1:
with ops.device(None):
with ops.colocate_with(None, ignore_existing=True):
in_grads = control_flow_ops.tuple(in_grads)
@@ -635,8 +636,8 @@ def gradients(ys,
"Incompatible shapes between op input and calculated "
"input gradient. Forward operation: %s. Input index: %d. "
"Original input shape: %s. "
- "Calculated input gradient shape: %s"
- % (op.name, i, t_in.shape, in_grad.shape))
+ "Calculated input gradient shape: %s" %
+ (op.name, i, t_in.shape, in_grad.shape))
_SetGrad(grads, t_in, in_grad)
if loop_state:
loop_state.ExitGradWhileContext(op, before=False)
@@ -668,8 +669,8 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
pending_count[x.op._id] -= 1
ready = (pending_count[x.op._id] == 0)
if loop_state and not ready:
- ready = (pending_count[x.op._id] > 0 and
- control_flow_util.IsLoopSwitch(x.op))
+ ready = (
+ pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
# pylint: enable=protected-access
if ready:
if control_flow_util.IsLoopExit(x.op):
@@ -723,8 +724,8 @@ def _GetGrad(grads, t):
if not op_grads:
return None
t_grad = op_grads[t.value_index]
- assert not isinstance(t_grad, list), (
- "gradients list should have been aggregated by now.")
+ assert not isinstance(
+ t_grad, list), ("gradients list should have been aggregated by now.")
return t_grad
@@ -743,9 +744,8 @@ def _HandleNestedIndexedSlices(grad):
else:
assert isinstance(grad.values, ops.IndexedSlices)
g = _HandleNestedIndexedSlices(grad.values)
- return ops.IndexedSlices(g.values,
- array_ops.gather(grad.indices, g.indices),
- g.dense_shape)
+ return ops.IndexedSlices(g.values, array_ops.gather(
+ grad.indices, g.indices), g.dense_shape)
def _AccumulatorShape(inputs):
@@ -799,6 +799,7 @@ def _MultiDeviceAddN(tensor_list):
return math_ops.add_n(summands)
+@tf_export("AggregationMethod")
class AggregationMethod(object):
"""A class listing aggregation methods used to combine gradients.
@@ -846,8 +847,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
]:
- raise ValueError("Invalid aggregation_method specified %s." %
- aggregation_method)
+ raise ValueError(
+ "Invalid aggregation_method specified %s." % aggregation_method)
out_grads = _GetGrads(grads, op)
for i, out_grad in enumerate(out_grads):
if loop_state:
@@ -856,7 +857,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
continue
# Grads have to be Tensors or IndexedSlices
if (isinstance(out_grad, collections.Sequence) and not all([
- isinstance(g, (ops.Tensor, ops.IndexedSlices)) for g in out_grad
+ isinstance(g, (ops.Tensor, ops.IndexedSlices))
+ for g in out_grad
if g is not None
])):
raise TypeError("gradients have to be either all Tensors "
@@ -900,8 +902,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
else:
used = "add_n"
out_grads[i] = _MultiDeviceAddN(out_grad)
- logging.vlog(2, " _AggregatedGrads %d x %s using %s",
- len(out_grad), tensor_shape, used)
+ logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
+ tensor_shape, used)
else:
out_grad = math_ops._as_indexed_slices_list(
[g for g in out_grad if g is not None])
@@ -964,15 +966,21 @@ def _hessian_vector_product(ys, xs, v):
assert len(grads) == length
elemwise_products = [
math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem))
- for grad_elem, v_elem in zip(grads, v) if grad_elem is not None
+ for grad_elem, v_elem in zip(grads, v)
+ if grad_elem is not None
]
# Second backprop
return gradients(elemwise_products, xs)
-def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
- gate_gradients=False, aggregation_method=None):
+@tf_export("hessians")
+def hessians(ys,
+ xs,
+ name="hessians",
+ colocate_gradients_with_ops=False,
+ gate_gradients=False,
+ aggregation_method=None):
"""Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.
`hessians()` adds ops to the graph to output the Hessian matrix of `ys`
@@ -1000,9 +1008,9 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
"""
xs = _AsList(xs)
kwargs = {
- 'colocate_gradients_with_ops': colocate_gradients_with_ops,
- 'gate_gradients': gate_gradients,
- 'aggregation_method': aggregation_method
+ "colocate_gradients_with_ops": colocate_gradients_with_ops,
+ "gate_gradients": gate_gradients,
+ "aggregation_method": aggregation_method
}
# Compute first-order derivatives and iterate for each x in xs.
hessians = []
@@ -1027,8 +1035,7 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
)
_shape = array_ops.shape(x)
- _reshaped_hessian = array_ops.reshape(
- hessian.stack(), array_ops.concat((_shape, _shape), 0)
- )
+ _reshaped_hessian = array_ops.reshape(hessian.stack(),
+ array_ops.concat((_shape, _shape), 0))
hessians.append(_reshaped_hessian)
return hessians
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index 51e4be9343..f079e56b10 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -17,6 +17,7 @@
Please see @{$python/histogram_ops} guide.
+@@histogram_fixed_width_bins
@@histogram_fixed_width
"""
@@ -30,8 +31,75 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+def histogram_fixed_width_bins(values,
+ value_range,
+ nbins=100,
+ dtype=dtypes.int32,
+ name=None):
+ """Bins the given values for use in a histogram.
+
+ Given the tensor `values`, this operation returns a rank 1 `Tensor`
+ representing the indices of a histogram into which each element
+ of `values` would be binned. The bins are equal width and
+ determined by the arguments `value_range` and `nbins`.
+
+ Args:
+ values: Numeric `Tensor`.
+ value_range: Shape [2] `Tensor` of same `dtype` as `values`.
+ values <= value_range[0] will be mapped to hist[0],
+ values >= value_range[1] will be mapped to hist[-1].
+ nbins: Scalar `int32 Tensor`. Number of histogram bins.
+ dtype: dtype for returned histogram.
+ name: A name for this operation (defaults to 'histogram_fixed_width').
+
+ Returns:
+ A `Tensor` holding the indices of the binned values whose shape matches
+ `values`.
+
+ Examples:
+
+ ```python
+ # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ nbins = 5
+ value_range = [0.0, 5.0]
+ new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
+
+ with tf.get_default_session() as sess:
+ indices = tf.histogram_fixed_width_bins(new_values, value_range, nbins=5)
+ variables.global_variables_initializer().run()
+ sess.run(indices) => [0, 0, 1, 2, 4]
+ ```
+ """
+ with ops.name_scope(name, 'histogram_fixed_width_bins',
+ [values, value_range, nbins]):
+ values = ops.convert_to_tensor(values, name='values')
+ shape = array_ops.shape(values)
+
+ values = array_ops.reshape(values, [-1])
+ value_range = ops.convert_to_tensor(value_range, name='value_range')
+ nbins = ops.convert_to_tensor(nbins, dtype=dtypes.int32, name='nbins')
+ nbins_float = math_ops.cast(nbins, values.dtype)
+
+ # Map tensor values that fall within value_range to [0, 1].
+ scaled_values = math_ops.truediv(
+ values - value_range[0],
+ value_range[1] - value_range[0],
+ name='scaled_values')
+
+ # map tensor values within the open interval value_range to {0,.., nbins-1},
+ # values outside the open interval will be zero or less, or nbins or more.
+ indices = math_ops.floor(nbins_float * scaled_values, name='indices')
+
+ # Clip edge cases (e.g. value = value_range[1]) or "outliers."
+ indices = math_ops.cast(
+ clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32)
+ return array_ops.reshape(indices, shape)
+
+
+@tf_export('histogram_fixed_width')
def histogram_fixed_width(values,
value_range,
nbins=100,
@@ -71,5 +139,5 @@ def histogram_fixed_width(values,
"""
with ops.name_scope(name, 'histogram_fixed_width',
[values, value_range, nbins]) as name:
- return gen_math_ops._histogram_fixed_width(values, value_range, nbins,
- dtype=dtype, name=name)
+ return gen_math_ops._histogram_fixed_width( # pylint: disable=protected-access
+ values, value_range, nbins, dtype=dtype, name=name)
diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py
index 19ad6cd2ba..a226ac81bb 100644
--- a/tensorflow/python/ops/histogram_ops_test.py
+++ b/tensorflow/python/ops/histogram_ops_test.py
@@ -21,11 +21,64 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import histogram_ops
from tensorflow.python.platform import test
+class BinValuesFixedWidth(test.TestCase):
+
+ def test_empty_input_gives_all_zero_counts(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = [0.0, 5.0]
+ values = []
+ expected_bins = []
+ with self.test_session():
+ bins = histogram_ops.histogram_fixed_width_bins(
+ values, value_range, nbins=5)
+ self.assertEqual(dtypes.int32, bins.dtype)
+ self.assertAllClose(expected_bins, bins.eval())
+
+ def test_1d_values_int32_output(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = [0.0, 5.0]
+ values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
+ expected_bins = [0, 0, 1, 2, 4, 4]
+ with self.test_session():
+ bins = histogram_ops.histogram_fixed_width_bins(
+ values, value_range, nbins=5, dtype=dtypes.int64)
+ self.assertEqual(dtypes.int32, bins.dtype)
+ self.assertAllClose(expected_bins, bins.eval())
+
+ def test_1d_float64_values_int32_output(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = np.float64([0.0, 5.0])
+ values = np.float64([-1.0, 0.0, 1.5, 2.0, 5.0, 15])
+ expected_bins = [0, 0, 1, 2, 4, 4]
+ with self.test_session():
+ bins = histogram_ops.histogram_fixed_width_bins(
+ values, value_range, nbins=5)
+ self.assertEqual(dtypes.int32, bins.dtype)
+ self.assertAllClose(expected_bins, bins.eval())
+
+ def test_2d_values(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = [0.0, 5.0]
+ values = constant_op.constant(
+ [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]], shape=(2, 3))
+ expected_bins = [[0, 0, 1], [2, 4, 4]]
+ with self.test_session():
+ bins = histogram_ops.histogram_fixed_width_bins(
+ values, value_range, nbins=5)
+ self.assertEqual(dtypes.int32, bins.dtype)
+ self.assertAllClose(expected_bins, bins.eval())
+
+
class HistogramFixedWidthTest(test.TestCase):
def setUp(self):
@@ -87,8 +140,8 @@ class HistogramFixedWidthTest(test.TestCase):
self.assertEqual(dtypes.int32, hist.dtype)
self.assertAllClose(expected_bin_counts, hist.eval())
- hist = histogram_ops.histogram_fixed_width(values, value_range,
- nbins=placeholder)
+ hist = histogram_ops.histogram_fixed_width(
+ values, value_range, nbins=placeholder)
self.assertEquals(hist.shape.ndims, 1)
self.assertIs(hist.shape[0].value, None)
self.assertEqual(dtypes.int32, hist.dtype)
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 3b0b5a978c..de12c5f63f 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -49,6 +49,10 @@ See the @{$python/image} guide.
@@grayscale_to_rgb
@@hsv_to_rgb
@@rgb_to_hsv
+@@rgb_to_yiq
+@@yiq_to_rgb
+@@rgb_to_yuv
+@@yuv_to_rgb
@@convert_image_dtype
@@adjust_brightness
@@random_brightness
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 9f09d0a4d1..cab1025df1 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Implementation of image ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
-
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -28,7 +25,6 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
@@ -36,7 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
-
+from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable('RandomCrop')
# TODO(b/31222613): This op may be differentiable, and there may be
@@ -109,8 +105,9 @@ def _ImageDimensions(image, rank):
else:
static_shape = image.get_shape().with_rank(rank).as_list()
dynamic_shape = array_ops.unstack(array_ops.shape(image), rank)
- return [s if s is not None else d
- for s, d in zip(static_shape, dynamic_shape)]
+ return [
+ s if s is not None else d for s, d in zip(static_shape, dynamic_shape)
+ ]
def _Check3DImage(image, require_static=True):
@@ -131,22 +128,45 @@ def _Check3DImage(image, require_static=True):
try:
image_shape = image.get_shape().with_rank(3)
except ValueError:
- raise ValueError("'image' (shape %s) must be three-dimensional." %
- image.shape)
+ raise ValueError(
+ "'image' (shape %s) must be three-dimensional." % image.shape)
if require_static and not image_shape.is_fully_defined():
- raise ValueError("'image' (shape %s) must be fully defined." %
- image_shape)
+ raise ValueError("'image' (shape %s) must be fully defined." % image_shape)
if any(x == 0 for x in image_shape):
- raise ValueError("all dims of 'image.shape' must be > 0: %s" %
- image_shape)
+ raise ValueError("all dims of 'image.shape' must be > 0: %s" % image_shape)
if not image_shape.is_fully_defined():
- return [check_ops.assert_positive(array_ops.shape(image),
- ["all dims of 'image.shape' "
- "must be > 0."])]
+ return [
+ check_ops.assert_positive(
+ array_ops.shape(image),
+ ["all dims of 'image.shape' "
+ 'must be > 0.'])
+ ]
else:
return []
+def _Assert3DImage(image):
+ """Assert that we are working with a properly shaped image.
+
+ Performs the check statically if possible (i.e. if the shape
+ is statically known). Otherwise adds a control dependency
+ to an assert op that checks the dynamic shape.
+
+ Args:
+ image: 3-D Tensor of shape [height, width, channels]
+
+ Raises:
+ ValueError: if `image.shape` is not a 3-vector.
+
+ Returns:
+ If the shape of `image` could be verified statically, `image` is
+ returned unchanged, otherwise there will be a control dependency
+ added that asserts the correct dynamic shape.
+ """
+ return control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
+
+
def _CheckAtLeast3DImage(image, require_static=True):
"""Assert that we are working with properly shaped image.
@@ -172,12 +192,15 @@ def _CheckAtLeast3DImage(image, require_static=True):
if require_static and not image_shape.is_fully_defined():
raise ValueError('\'image\' must be fully defined.')
if any(x == 0 for x in image_shape):
- raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
- image_shape)
+ raise ValueError(
+ 'all dims of \'image.shape\' must be > 0: %s' % image_shape)
if not image_shape.is_fully_defined():
- return [check_ops.assert_positive(array_ops.shape(image),
- ["all dims of 'image.shape' "
- "must be > 0."])]
+ return [
+ check_ops.assert_positive(
+ array_ops.shape(image),
+ ["all dims of 'image.shape' "
+ 'must be > 0.'])
+ ]
else:
return []
@@ -201,6 +224,7 @@ def fix_image_flip_shape(image, result):
return result
+@tf_export('image.random_flip_up_down')
def random_flip_up_down(image, seed=None):
"""Randomly flips an image vertically (upside down).
@@ -221,17 +245,18 @@ def random_flip_up_down(image, seed=None):
"""
with ops.name_scope(None, 'random_flip_up_down', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
- result = control_flow_ops.cond(mirror_cond,
- lambda: array_ops.reverse(image, [0]),
- lambda: image,
- name=scope)
+ result = control_flow_ops.cond(
+ mirror_cond,
+ lambda: array_ops.reverse(image, [0]),
+ lambda: image,
+ name=scope)
return fix_image_flip_shape(image, result)
+@tf_export('image.random_flip_left_right')
def random_flip_left_right(image, seed=None):
"""Randomly flip an image horizontally (left to right).
@@ -252,17 +277,18 @@ def random_flip_left_right(image, seed=None):
"""
with ops.name_scope(None, 'random_flip_left_right', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
- result = control_flow_ops.cond(mirror_cond,
- lambda: array_ops.reverse(image, [1]),
- lambda: image,
- name=scope)
+ result = control_flow_ops.cond(
+ mirror_cond,
+ lambda: array_ops.reverse(image, [1]),
+ lambda: image,
+ name=scope)
return fix_image_flip_shape(image, result)
+@tf_export('image.flip_left_right')
def flip_left_right(image):
"""Flip an image horizontally (left to right).
@@ -282,12 +308,12 @@ def flip_left_right(image):
"""
with ops.name_scope(None, 'flip_left_right', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
- return fix_image_flip_shape(image,
- array_ops.reverse(image, [1], name=scope))
+ image = _Assert3DImage(image)
+ return fix_image_flip_shape(image, array_ops.reverse(
+ image, [1], name=scope))
+@tf_export('image.flip_up_down')
def flip_up_down(image):
"""Flip an image vertically (upside down).
@@ -307,12 +333,12 @@ def flip_up_down(image):
"""
with ops.name_scope(None, 'flip_up_down', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
- return fix_image_flip_shape(image,
- array_ops.reverse(image, [0], name=scope))
+ image = _Assert3DImage(image)
+ return fix_image_flip_shape(image, array_ops.reverse(
+ image, [0], name=scope))
+@tf_export('image.rot90')
def rot90(image, k=1, name=None):
"""Rotate an image counter-clockwise by 90 degrees.
@@ -326,30 +352,30 @@ def rot90(image, k=1, name=None):
"""
with ops.name_scope(name, 'rot90', [image, k]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
k.get_shape().assert_has_rank(0)
k = math_ops.mod(k, 4)
def _rot90():
- return array_ops.transpose(array_ops.reverse_v2(image, [1]),
- [1, 0, 2])
+ return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2])
+
def _rot180():
return array_ops.reverse_v2(image, [0, 1])
+
def _rot270():
- return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]),
- [1])
- cases = [(math_ops.equal(k, 1), _rot90),
- (math_ops.equal(k, 2), _rot180),
+ return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1])
+
+ cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180),
(math_ops.equal(k, 3), _rot270)]
- ret = control_flow_ops.case(cases, default=lambda: image, exclusive=True,
- name=scope)
+ ret = control_flow_ops.case(
+ cases, default=lambda: image, exclusive=True, name=scope)
ret.set_shape([None, None, image.get_shape()[2]])
return ret
+@tf_export('image.transpose_image')
def transpose_image(image):
"""Transpose an image by swapping the first and second dimension.
@@ -366,11 +392,11 @@ def transpose_image(image):
"""
with ops.name_scope(None, 'transpose_image', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
return array_ops.transpose(image, [1, 0, 2], name=scope)
+@tf_export('image.central_crop')
def central_crop(image, central_fraction):
"""Crop the central region of the image.
@@ -402,8 +428,7 @@ def central_crop(image, central_fraction):
if central_fraction == 1.0:
return image
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
img_shape = array_ops.shape(image)
depth = image.get_shape()[2]
@@ -424,6 +449,7 @@ def central_crop(image, central_fraction):
return image
+@tf_export('image.pad_to_bounding_box')
def pad_to_bounding_box(image, offset_height, offset_width, target_height,
target_width):
"""Pad `image` with zeros to the specified `height` and `width`.
@@ -494,8 +520,10 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
]), [4, 2])
padded = array_ops.pad(image, paddings)
- padded_shape = [None if _is_tensor(i) else i
- for i in [batch, target_height, target_width, depth]]
+ padded_shape = [
+ None if _is_tensor(i) else i
+ for i in [batch, target_height, target_width, depth]
+ ]
padded.set_shape(padded_shape)
if not is_batch:
@@ -504,6 +532,7 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
return padded
+@tf_export('image.crop_to_bounding_box')
def crop_to_bounding_box(image, offset_height, offset_width, target_height,
target_width):
"""Crops an image to a specified bounding box.
@@ -568,12 +597,13 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
image = control_flow_ops.with_dependencies(assert_ops, image)
cropped = array_ops.slice(
- image,
- array_ops.stack([0, offset_height, offset_width, 0]),
+ image, array_ops.stack([0, offset_height, offset_width, 0]),
array_ops.stack([-1, target_height, target_width, -1]))
- cropped_shape = [None if _is_tensor(i) else i
- for i in [batch, target_height, target_width, depth]]
+ cropped_shape = [
+ None if _is_tensor(i) else i
+ for i in [batch, target_height, target_width, depth]
+ ]
cropped.set_shape(cropped_shape)
if not is_batch:
@@ -582,6 +612,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
return cropped
+@tf_export('image.resize_image_with_crop_or_pad')
def resize_image_with_crop_or_pad(image, target_height, target_width):
"""Crops and/or pads an image to a target width and height.
@@ -637,8 +668,8 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
target_height = control_flow_ops.with_dependencies(
assert_ops, target_height)
if _is_tensor(target_width):
- target_width = control_flow_ops.with_dependencies(
- assert_ops, target_width)
+ target_width = control_flow_ops.with_dependencies(assert_ops,
+ target_width)
def max_(x, y):
if _is_tensor(x) or _is_tensor(y):
@@ -683,10 +714,12 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
_, resized_height, resized_width, _ = _ImageDimensions(resized, rank=4)
assert_ops = []
- assert_ops += _assert(equal_(resized_height, target_height), ValueError,
- 'resized height is not correct.')
- assert_ops += _assert(equal_(resized_width, target_width), ValueError,
- 'resized width is not correct.')
+ assert_ops += _assert(
+ equal_(resized_height, target_height), ValueError,
+ 'resized height is not correct.')
+ assert_ops += _assert(
+ equal_(resized_width, target_width), ValueError,
+ 'resized width is not correct.')
resized = control_flow_ops.with_dependencies(assert_ops, resized)
@@ -696,6 +729,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
return resized
+@tf_export('image.ResizeMethod')
class ResizeMethod(object):
BILINEAR = 0
NEAREST_NEIGHBOR = 1
@@ -703,6 +737,7 @@ class ResizeMethod(object):
AREA = 3
+@tf_export('image.resize_images')
def resize_images(images,
size,
method=ResizeMethod.BILINEAR,
@@ -735,8 +770,9 @@ def resize_images(images,
size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
method: ResizeMethod. Defaults to `ResizeMethod.BILINEAR`.
- align_corners: bool. If true, exactly align all 4 corners of the input and
- output. Defaults to `false`.
+ align_corners: bool. If True, the centers of the 4 corner pixels of the
+ input and output tensors are aligned, preserving the values at the
+ corner pixels. Defaults to `False`.
Raises:
ValueError: if the shape of `images` is incompatible with the
@@ -785,22 +821,17 @@ def resize_images(images,
return images
if method == ResizeMethod.BILINEAR:
- images = gen_image_ops.resize_bilinear(images,
- size,
- align_corners=align_corners)
+ images = gen_image_ops.resize_bilinear(
+ images, size, align_corners=align_corners)
elif method == ResizeMethod.NEAREST_NEIGHBOR:
- images = gen_image_ops.resize_nearest_neighbor(images,
- size,
- align_corners=
- align_corners)
+ images = gen_image_ops.resize_nearest_neighbor(
+ images, size, align_corners=align_corners)
elif method == ResizeMethod.BICUBIC:
- images = gen_image_ops.resize_bicubic(images,
- size,
- align_corners=align_corners)
+ images = gen_image_ops.resize_bicubic(
+ images, size, align_corners=align_corners)
elif method == ResizeMethod.AREA:
- images = gen_image_ops.resize_area(images,
- size,
- align_corners=align_corners)
+ images = gen_image_ops.resize_area(
+ images, size, align_corners=align_corners)
else:
raise ValueError('Resize method is not implemented.')
@@ -813,6 +844,7 @@ def resize_images(images,
return images
+@tf_export('image.per_image_standardization')
def per_image_standardization(image):
"""Linearly scales `image` to have zero mean and unit norm.
@@ -834,15 +866,15 @@ def per_image_standardization(image):
"""
with ops.name_scope(None, 'per_image_standardization', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
num_pixels = math_ops.reduce_prod(array_ops.shape(image))
image = math_ops.cast(image, dtype=dtypes.float32)
image_mean = math_ops.reduce_mean(image)
- variance = (math_ops.reduce_mean(math_ops.square(image)) -
- math_ops.square(image_mean))
+ variance = (
+ math_ops.reduce_mean(math_ops.square(image)) -
+ math_ops.square(image_mean))
variance = gen_nn_ops.relu(variance)
stddev = math_ops.sqrt(variance)
@@ -856,6 +888,7 @@ def per_image_standardization(image):
return image
+@tf_export('image.random_brightness')
def random_brightness(image, max_delta, seed=None):
"""Adjust the brightness of images by a random factor.
@@ -882,6 +915,7 @@ def random_brightness(image, max_delta, seed=None):
return adjust_brightness(image, delta)
+@tf_export('image.random_contrast')
def random_contrast(image, lower, upper, seed=None):
"""Adjust the contrast of an image by a random factor.
@@ -913,6 +947,7 @@ def random_contrast(image, lower, upper, seed=None):
return adjust_contrast(image, contrast_factor)
+@tf_export('image.adjust_brightness')
def adjust_brightness(image, delta):
"""Adjust the brightness of RGB or Grayscale images.
@@ -940,13 +975,13 @@ def adjust_brightness(image, delta):
orig_dtype = image.dtype
flt_image = convert_image_dtype(image, dtypes.float32)
- adjusted = math_ops.add(flt_image,
- math_ops.cast(delta, dtypes.float32),
- name=name)
+ adjusted = math_ops.add(
+ flt_image, math_ops.cast(delta, dtypes.float32), name=name)
return convert_image_dtype(adjusted, orig_dtype, saturate=True)
+@tf_export('image.adjust_contrast')
def adjust_contrast(images, contrast_factor):
"""Adjust contrast of RGB or grayscale images.
@@ -980,14 +1015,14 @@ def adjust_contrast(images, contrast_factor):
flt_images = convert_image_dtype(images, dtypes.float32)
# pylint: disable=protected-access
- adjusted = gen_image_ops._adjust_contrastv2(flt_images,
- contrast_factor=contrast_factor,
- name=name)
+ adjusted = gen_image_ops._adjust_contrastv2(
+ flt_images, contrast_factor=contrast_factor, name=name)
# pylint: enable=protected-access
return convert_image_dtype(adjusted, orig_dtype, saturate=True)
+@tf_export('image.adjust_gamma')
def adjust_gamma(image, gamma=1, gain=1):
"""Performs Gamma Correction on the input image.
@@ -1026,16 +1061,17 @@ def adjust_gamma(image, gamma=1, gain=1):
'Gamma should be a non-negative real number.')
if assert_op:
gamma = control_flow_ops.with_dependencies(assert_op, gamma)
-
+
# scale = max(dtype) - min(dtype).
- scale = constant_op.constant(image.dtype.limits[1] - image.dtype.limits[0],
- dtype=dtypes.float32)
+ scale = constant_op.constant(
+ image.dtype.limits[1] - image.dtype.limits[0], dtype=dtypes.float32)
# According to the definition of gamma correction.
- adjusted_img = (img / scale) ** gamma * scale * gain
+ adjusted_img = (img / scale)**gamma * scale * gain
return adjusted_img
+@tf_export('image.convert_image_dtype')
def convert_image_dtype(image, dtype, saturate=False, name=None):
"""Convert `image` to `dtype`, scaling its values if needed.
@@ -1114,6 +1150,7 @@ def convert_image_dtype(image, dtype, saturate=False, name=None):
return math_ops.cast(scaled, dtype, name=name)
+@tf_export('image.rgb_to_grayscale')
def rgb_to_grayscale(images, name=None):
"""Converts one or more images from RGB to Grayscale.
@@ -1143,6 +1180,7 @@ def rgb_to_grayscale(images, name=None):
return convert_image_dtype(gray_float, orig_dtype, name=name)
+@tf_export('image.grayscale_to_rgb')
def grayscale_to_rgb(images, name=None):
"""Converts one or more images from Grayscale to RGB.
@@ -1159,9 +1197,8 @@ def grayscale_to_rgb(images, name=None):
with ops.name_scope(name, 'grayscale_to_rgb', [images]) as name:
images = ops.convert_to_tensor(images, name='images')
rank_1 = array_ops.expand_dims(array_ops.rank(images) - 1, 0)
- shape_list = (
- [array_ops.ones(rank_1,
- dtype=dtypes.int32)] + [array_ops.expand_dims(3, 0)])
+ shape_list = ([array_ops.ones(rank_1, dtype=dtypes.int32)] +
+ [array_ops.expand_dims(3, 0)])
multiples = array_ops.concat(shape_list, 0)
rgb = array_ops.tile(images, multiples, name=name)
rgb.set_shape(images.get_shape()[:-1].concatenate([3]))
@@ -1169,6 +1206,7 @@ def grayscale_to_rgb(images, name=None):
# pylint: disable=invalid-name
+@tf_export('image.random_hue')
def random_hue(image, max_delta, seed=None):
"""Adjust the hue of an RGB image by a random factor.
@@ -1201,6 +1239,7 @@ def random_hue(image, max_delta, seed=None):
return adjust_hue(image, delta)
+@tf_export('image.adjust_hue')
def adjust_hue(image, delta, name=None):
"""Adjust hue of an RGB image.
@@ -1234,6 +1273,7 @@ def adjust_hue(image, delta, name=None):
return convert_image_dtype(rgb_altered, orig_dtype)
+@tf_export('image.random_saturation')
def random_saturation(image, lower, upper, seed=None):
"""Adjust the saturation of an RGB image by a random factor.
@@ -1266,6 +1306,7 @@ def random_saturation(image, lower, upper, seed=None):
return adjust_saturation(image, saturation_factor)
+@tf_export('image.adjust_saturation')
def adjust_saturation(image, saturation_factor, name=None):
"""Adjust saturation of an RGB image.
@@ -1297,6 +1338,8 @@ def adjust_saturation(image, saturation_factor, name=None):
gen_image_ops.adjust_saturation(flt_image, saturation_factor),
orig_dtype)
+
+@tf_export('image.decode_image')
def decode_image(contents, channels=None, name=None):
"""Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`,
and `decode_png`.
@@ -1351,8 +1394,7 @@ def decode_image(contents, channels=None, name=None):
gif_channels = 0 if channels is None else channels
good_channels = math_ops.logical_and(
math_ops.not_equal(gif_channels, 1, name='check_gif_channels'),
- math_ops.not_equal(gif_channels, 4, name='check_gif_channels')
- )
+ math_ops.not_equal(gif_channels, 4, name='check_gif_channels'))
channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
@@ -1375,8 +1417,8 @@ def decode_image(contents, channels=None, name=None):
def _jpeg():
"""Decodes a jpeg image."""
jpeg_channels = 0 if channels is None else channels
- good_channels = math_ops.not_equal(jpeg_channels, 4,
- name='check_jpeg_channels')
+ good_channels = math_ops.not_equal(
+ jpeg_channels, 4, name='check_jpeg_channels')
channels_msg = ('Channels must be in (None, 0, 1, 3) when decoding JPEG '
'images')
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
@@ -1389,6 +1431,7 @@ def decode_image(contents, channels=None, name=None):
return control_flow_ops.cond(is_jpeg, _jpeg, check_png, name='cond_jpeg')
+@tf_export('image.total_variation')
def total_variation(images, name=None):
"""Calculate and return the total variation for one or more images.
@@ -1453,15 +1496,21 @@ def total_variation(images, name=None):
# Calculate the total variation by taking the absolute value of the
# pixel-differences and summing over the appropriate axis.
- tot_var = (math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) +
- math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis))
+ tot_var = (
+ math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) +
+ math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis))
return tot_var
-def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
- seed2=None, min_object_covered=None,
- aspect_ratio_range=None, area_range=None,
+@tf_export('image.sample_distorted_bounding_box')
+def sample_distorted_bounding_box(image_size,
+ bounding_boxes,
+ seed=None,
+ seed2=None,
+ min_object_covered=0.1,
+ aspect_ratio_range=None,
+ area_range=None,
max_attempts=None,
use_image_if_no_bounding_boxes=None,
name=None):
@@ -1477,10 +1526,12 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
The output of this Op is a single bounding box that may be used to crop the
original image. The output is returned as 3 tensors: `begin`, `size` and
`bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
- image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
+ image. The latter may be supplied to `tf.image.draw_bounding_boxes` to
+ visualize
what the bounding box looks like.
- Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
+ Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`.
+ The
bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
height of the underlying image.
@@ -1508,23 +1559,27 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
false and no bounding boxes are supplied, an error is raised.
Args:
- image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`.
+ image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
+ `int16`, `int32`, `int64`.
1-D, containing `[height, width, channels]`.
bounding_boxes: A `Tensor` of type `float32`.
3-D with shape `[batch, N, 4]` describing the N bounding boxes
associated with the image.
seed: An optional `int`. Defaults to `0`.
If either `seed` or `seed2` are set to non-zero, the random number
- generator is seeded by the given `seed`. Otherwise, it is seeded by a random
+ generator is seeded by the given `seed`. Otherwise, it is seeded by a
+ random
seed.
seed2: An optional `int`. Defaults to `0`.
A second seed to avoid seed collision.
min_object_covered: A Tensor of type `float32`. Defaults to `0.1`.
The cropped area of the image must contain at least this
- fraction of any bounding box supplied. The value of this parameter should be
+ fraction of any bounding box supplied. The value of this parameter should
+ be
non-negative. In the case of 0, the cropped area does not need to overlap
any of the bounding boxes supplied.
- aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75, 1.33]`.
+ aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75,
+ 1.33]`.
The cropped area of the image must have an aspect ratio =
width / height within this range.
area_range: An optional list of `floats`. Defaults to `[0.05, 1]`.
@@ -1532,34 +1587,44 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None,
supplied image within in this range.
max_attempts: An optional `int`. Defaults to `100`.
Number of attempts at generating a cropped region of the image
- of the specified constraints. After `max_attempts` failures, return the entire
+ of the specified constraints. After `max_attempts` failures, return the
+ entire
image.
use_image_if_no_bounding_boxes: An optional `bool`. Defaults to `False`.
Controls behavior if no bounding boxes supplied.
- If true, assume an implicit bounding box covering the whole input. If false,
+ If true, assume an implicit bounding box covering the whole input. If
+ false,
raise an error.
name: A name for the operation (optional).
Returns:
A tuple of `Tensor` objects (begin, size, bboxes).
- begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
+ begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing
+ `[offset_height, offset_width, 0]`. Provide as input to
`tf.slice`.
- size: A `Tensor`. Has the same type as `image_size`. 1-D, containing `[target_height, target_width, -1]`. Provide as input to
+ size: A `Tensor`. Has the same type as `image_size`. 1-D, containing
+ `[target_height, target_width, -1]`. Provide as input to
`tf.slice`.
- bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing the distorted bounding box.
+ bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing
+ the distorted bounding box.
Provide as input to `tf.image.draw_bounding_boxes`.
"""
with ops.name_scope(name, 'sample_distorted_bounding_box'):
- return gen_image_ops._sample_distorted_bounding_box_v2(image_size,
- bounding_boxes, seed=seed,
- seed2=seed2, min_object_covered=min_object_covered,
- aspect_ratio_range=aspect_ratio_range, area_range=area_range,
- max_attempts=max_attempts,
- use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes,
- name=name)
-
-
+ return gen_image_ops._sample_distorted_bounding_box_v2( # pylint: disable=protected-access
+ image_size,
+ bounding_boxes,
+ seed=seed,
+ seed2=seed2,
+ min_object_covered=min_object_covered,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ max_attempts=max_attempts,
+ use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes,
+ name=name)
+
+
+@tf_export('image.non_max_suppression')
def non_max_suppression(boxes,
scores,
max_output_size,
@@ -1604,3 +1669,106 @@ def non_max_suppression(boxes,
return gen_image_ops._non_max_suppression_v2(boxes, scores, max_output_size,
iou_threshold)
# pylint: enable=protected-access
+
+
+_rgb_to_yiq_kernel = [[0.299, 0.59590059, 0.2115],
+ [0.587, -0.27455667, -0.52273617],
+ [0.114, -0.32134392, 0.31119955]]
+
+
+def rgb_to_yiq(images):
+ """Converts one or more images from RGB to YIQ.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the YIQ
+ value of the pixels.
+ The output is only well defined if the value in images are in [0,1].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(_rgb_to_yiq_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]])
+
+
+_yiq_to_rgb_kernel = [[1, 1, 1],
+ [0.95598634, -0.27201283, -1.10674021],
+ [0.6208248, -0.64720424, 1.70423049]]
+
+
+def yiq_to_rgb(images):
+ """Converts one or more images from YIQ to RGB.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+ value of the pixels.
+ The output is only well defined if the Y value in images are in [0,1],
+ I value are in [-0.5957,0.5957] and Q value are in [-0.5226,0.5226].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(_yiq_to_rgb_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]])
+
+
+_rgb_to_yuv_kernel = [[0.299, -0.14714119, 0.61497538],
+ [0.587, -0.28886916, -0.51496512],
+ [0.114, 0.43601035, -0.10001026]]
+
+
+def rgb_to_yuv(images):
+ """Converts one or more images from RGB to YUV.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the YUV
+ value of the pixels.
+ The output is only well defined if the value in images are in [0,1].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(_rgb_to_yuv_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]])
+
+
+_yuv_to_rgb_kernel = [[1, 1, 1],
+ [0, -0.394642334, 2.03206185],
+ [1.13988303, -0.58062185, 0]]
+
+
+def yuv_to_rgb(images):
+ """Converts one or more images from YUV to RGB.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+ value of the pixels.
+ The output is only well defined if the Y value in images are in [0,1],
+ U and V value are in [-0.5,0.5].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(_yuv_to_rgb_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]])
+
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 3a49d41c9e..9834384634 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -85,6 +85,64 @@ class RGBToHSVTest(test_util.TensorFlowTestCase):
self.assertAllClose(rgb_tf, rgb_np)
+class RGBToYIQTest(test_util.TensorFlowTestCase):
+
+ def testBatch(self):
+ # Build an arbitrary RGB image
+ np.random.seed(7)
+ batch_size = 5
+ shape = (batch_size, 2, 7, 3)
+
+ for nptype in [np.float32, np.float64]:
+ inp = np.random.rand(*shape).astype(nptype)
+
+ # Convert to YIQ and back, as a batch and individually
+ with self.test_session(use_gpu=True) as sess:
+ batch0 = constant_op.constant(inp)
+ batch1 = image_ops.rgb_to_yiq(batch0)
+ batch2 = image_ops.yiq_to_rgb(batch1)
+ split0 = array_ops.unstack(batch0)
+ split1 = list(map(image_ops.rgb_to_yiq, split0))
+ split2 = list(map(image_ops.yiq_to_rgb, split1))
+ join1 = array_ops.stack(split1)
+ join2 = array_ops.stack(split2)
+ batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+
+ # Verify that processing batch elements together is the same as separate
+ self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4)
+
+
+class RGBToYUVTest(test_util.TensorFlowTestCase):
+
+ def testBatch(self):
+ # Build an arbitrary RGB image
+ np.random.seed(7)
+ batch_size = 5
+ shape = (batch_size, 2, 7, 3)
+
+ for nptype in [np.float32, np.float64]:
+ inp = np.random.rand(*shape).astype(nptype)
+
+ # Convert to YUV and back, as a batch and individually
+ with self.test_session(use_gpu=True) as sess:
+ batch0 = constant_op.constant(inp)
+ batch1 = image_ops.rgb_to_yuv(batch0)
+ batch2 = image_ops.yuv_to_rgb(batch1)
+ split0 = array_ops.unstack(batch0)
+ split1 = list(map(image_ops.rgb_to_yuv, split0))
+ split2 = list(map(image_ops.yuv_to_rgb, split1))
+ join1 = array_ops.stack(split1)
+ join2 = array_ops.stack(split2)
+ batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+
+ # Verify that processing batch elements together is the same as separate
+ self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4)
+
+
class GrayscaleToRGBTest(test_util.TensorFlowTestCase):
def _RGBToGrayscale(self, images):
@@ -1857,6 +1915,25 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase):
self.assertAllEqual([3], end.get_shape().as_list())
self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list())
+ def testDefaultMinObjectCovered(self):
+ # By default min_object_covered=0.1 if not provided
+ with self.test_session(use_gpu=True):
+ image_size = constant_op.constant(
+ [40, 50, 1], shape=[3], dtype=dtypes.int32)
+ bounding_box = constant_op.constant(
+ [0.0, 0.0, 1.0, 1.0],
+ shape=[4],
+ dtype=dtypes.float32,)
+ begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box(
+ image_size=image_size,
+ bounding_boxes=bounding_box,
+ aspect_ratio_range=(0.75, 1.33),
+ area_range=(0.05, 1.0))
+
+ self.assertAllEqual([3], begin.get_shape().as_list())
+ self.assertAllEqual([3], end.get_shape().as_list())
+ self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list())
+
class ResizeImagesTest(test_util.TensorFlowTestCase):
@@ -2833,6 +2910,16 @@ class PngTest(test_util.TensorFlowTestCase):
class GifTest(test_util.TensorFlowTestCase):
+ def testOptimizedGifErrorString(self):
+ filename = "tensorflow/core/lib/gif/testdata/optimized.gif"
+
+ with self.test_session(use_gpu=True) as sess:
+ gif = io_ops.read_file(filename)
+ image = image_ops.decode_gif(gif)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError, "can't process optimized gif"):
+ gif, image = sess.run([gif, image])
+
def testValid(self):
# Read some real GIFs
prefix = "tensorflow/core/lib/gif/testdata/"
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 5dc43d65b9..c7502d0fda 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -44,8 +44,10 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("keras.initializers.Initializer")
class Initializer(object):
"""Initializer base class: all initializers inherit from this class.
"""
@@ -83,6 +85,8 @@ class Initializer(object):
return cls(**config)
+@tf_export("keras.initializers.Zeros", "initializers.zeros",
+ "zeros_initializer")
class Zeros(Initializer):
"""Initializer that generates tensors initialized to 0."""
@@ -98,6 +102,7 @@ class Zeros(Initializer):
return {"dtype": self.dtype.name}
+@tf_export("keras.initializers.Ones", "initializers.ones", "ones_initializer")
class Ones(Initializer):
"""Initializer that generates tensors initialized to 1."""
@@ -113,6 +118,8 @@ class Ones(Initializer):
return {"dtype": self.dtype.name}
+@tf_export("keras.initializers.Constant", "initializers.constant",
+ "constant_initializer")
class Constant(Initializer):
"""Initializer that generates tensors with constant values.
@@ -217,6 +224,8 @@ class Constant(Initializer):
return {"value": self.value, "dtype": self.dtype.name}
+@tf_export("keras.initializers.RandomUniform", "initializers.random_uniform",
+ "random_uniform_initializer")
class RandomUniform(Initializer):
"""Initializer that generates tensors with a uniform distribution.
@@ -252,6 +261,8 @@ class RandomUniform(Initializer):
}
+@tf_export("keras.initializers.RandomNormal", "initializers.random_normal",
+ "random_normal_initializer")
class RandomNormal(Initializer):
"""Initializer that generates tensors with a normal distribution.
@@ -287,6 +298,8 @@ class RandomNormal(Initializer):
}
+@tf_export("keras.initializers.TruncatedNormal",
+ "initializers.truncated_normal", "truncated_normal_initializer")
class TruncatedNormal(Initializer):
"""Initializer that generates a truncated normal distribution.
@@ -327,6 +340,8 @@ class TruncatedNormal(Initializer):
}
+@tf_export("initializers.uniform_unit_scaling",
+ "uniform_unit_scaling_initializer")
class UniformUnitScaling(Initializer):
"""Initializer that generates tensors without scaling variance.
@@ -385,6 +400,8 @@ class UniformUnitScaling(Initializer):
return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name}
+@tf_export("keras.initializers.VarianceScaling",
+ "initializers.variance_scaling", "variance_scaling_initializer")
class VarianceScaling(Initializer):
"""Initializer capable of adapting its scale to the shape of weights tensors.
@@ -464,6 +481,8 @@ class VarianceScaling(Initializer):
}
+@tf_export("keras.initializers.Orthogonal", "initializers.orthogonal",
+ "orthogonal_initializer")
class Orthogonal(Initializer):
"""Initializer that generates an orthogonal matrix.
@@ -523,6 +542,7 @@ class Orthogonal(Initializer):
return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}
+@tf_export("keras.initializers.Identity", "initializers.identity")
class Identity(Initializer):
"""Initializer that generates the identity matrix.
@@ -570,6 +590,7 @@ identity_initializer = Identity
# pylint: enable=invalid-name
+@tf_export("glorot_uniform_initializer")
def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
"""The Glorot uniform initializer, also called Xavier uniform initializer.
@@ -593,6 +614,7 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
+@tf_export("glorot_normal_initializer")
def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
"""The Glorot normal initializer, also called Xavier normal initializer.
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 670bb9a9c2..5e70b3186f 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -79,6 +79,7 @@ from tensorflow.python.ops import gen_io_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_io_ops import *
+from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -140,6 +141,7 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
preferred_shard, name=name)
+@tf_export("ReaderBase")
class ReaderBase(object):
"""Base class for different Reader types, that produce a record every step.
@@ -354,6 +356,7 @@ ops.NotDifferentiable("ReaderRestoreState")
ops.NotDifferentiable("ReaderReset")
+@tf_export("WholeFileReader")
class WholeFileReader(ReaderBase):
"""A Reader that outputs the entire contents of a file as a value.
@@ -381,6 +384,7 @@ class WholeFileReader(ReaderBase):
ops.NotDifferentiable("WholeFileReader")
+@tf_export("TextLineReader")
class TextLineReader(ReaderBase):
"""A Reader that outputs the lines of a file delimited by newlines.
@@ -410,6 +414,7 @@ class TextLineReader(ReaderBase):
ops.NotDifferentiable("TextLineReader")
+@tf_export("FixedLengthRecordReader")
class FixedLengthRecordReader(ReaderBase):
"""A Reader that outputs fixed-length records from a file.
@@ -452,6 +457,7 @@ class FixedLengthRecordReader(ReaderBase):
ops.NotDifferentiable("FixedLengthRecordReader")
+@tf_export("TFRecordReader")
class TFRecordReader(ReaderBase):
"""A Reader that outputs the records from a TFRecords file.
@@ -482,6 +488,7 @@ class TFRecordReader(ReaderBase):
ops.NotDifferentiable("TFRecordReader")
+@tf_export("LMDBReader")
class LMDBReader(ReaderBase):
"""A Reader that outputs the records from a LMDB file.
@@ -506,6 +513,7 @@ class LMDBReader(ReaderBase):
ops.NotDifferentiable("LMDBReader")
+@tf_export("IdentityReader")
class IdentityReader(ReaderBase):
"""A Reader that outputs the queued work as both the key and value.
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 13a32c83d9..3cbbf3412a 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -277,20 +277,28 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
# https://j-towns.github.io/papers/svd-derivative.pdf
a = op.inputs[0]
a_shape = a.get_shape().with_rank_at_least(2)
+ grad_s_mat = array_ops.matrix_diag(grad_s)
- if op.get_attr("compute_uv"):
- # TODO(rmlarsen): Make this work with complex types.
- if a.dtype.is_complex:
- raise NotImplementedError(
- "SVD gradient is not implemented for complex types and "
- "compute_uv=True.")
- grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
- grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
- m = a_shape[-2].merge_with(grad_u_shape[-2])
- n = a_shape[-1].merge_with(grad_v_shape[-2])
- batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
- grad_v_shape[:-2])
- a_shape = batch_shape.concatenate([m, n])
+ if not op.get_attr("compute_uv"):
+ s, u, v = linalg_ops.svd(a, compute_uv=True)
+ grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
+ grad_a.set_shape(a_shape)
+ return grad_a
+
+ full_matrices = op.get_attr("full_matrices")
+
+ # TODO(rmlarsen): Make this work with complex types.
+ if a.dtype.is_complex:
+ raise NotImplementedError(
+ "SVD gradient is not implemented for complex types and "
+ "compute_uv=True.")
+ grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
+ grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
+ m = a_shape[-2].merge_with(grad_u_shape[-2])
+ n = a_shape[-1].merge_with(grad_v_shape[-2])
+ batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
+ grad_v_shape[:-2])
+ a_shape = batch_shape.concatenate([m, n])
m = a_shape[-2].value
n = a_shape[-1].value
@@ -300,12 +308,9 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
"SVD gradient has not been implemented for input with unknown "
"inner matrix shape.")
- if not op.get_attr("compute_uv"):
- s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True)
- else:
- s = op.outputs[0]
- u = op.outputs[1]
- v = op.outputs[2]
+ s = op.outputs[0]
+ u = op.outputs[1]
+ v = op.outputs[2]
use_adjoint = False
if m > n:
@@ -317,19 +322,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
grad_u, grad_v = grad_v, grad_u
with ops.control_dependencies([grad_s, grad_u, grad_v]):
- grad_s_mat = array_ops.matrix_diag(grad_s)
- if not op.get_attr("compute_uv"):
- if use_adjoint:
- grad_a = math_ops.matmul(
- v[..., :, :m], math_ops.matmul(u, grad_s_mat), adjoint_b=True)
- else:
- grad_a = math_ops.matmul(u,
- math_ops.matmul(
- grad_s_mat, v[..., :, :m], adjoint_b=True))
- grad_a.set_shape(a_shape)
- return grad_a
-
- if op.get_attr("full_matrices") and abs(m - n) > 1:
+ if full_matrices and abs(m - n) > 1:
raise NotImplementedError(
"svd gradient is not implemented for abs(m - n) > 1 "
"when full_matrices is True")
@@ -371,7 +364,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
gv1t_v1 = math_ops.matmul(gv1t, v1)
term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
- if op.get_attr("full_matrices"):
+ if full_matrices:
v2 = v[..., :, m:n]
grad_v2 = grad_v[..., :, m:n]
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index be9beee633..9803eed6ae 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops.gen_linalg_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
# Names below are lower_case.
# pylint: disable=invalid-name
@@ -77,6 +78,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
return gen_linalg_ops.cholesky(gramian)
+@tf_export('cholesky_solve', 'linalg.cholesky_solve')
def cholesky_solve(chol, rhs, name=None):
"""Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations.
@@ -119,6 +121,7 @@ def cholesky_solve(chol, rhs, name=None):
return x
+@tf_export('eye', 'linalg.eye')
def eye(num_rows,
num_columns=None,
batch_shape=None,
@@ -188,6 +191,7 @@ def eye(num_rows,
return array_ops.matrix_set_diag(zero_matrix, diag_ones)
+@tf_export('matrix_solve_ls', 'linalg.lstsq')
def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
r"""Solves one or more linear least-squares problems.
@@ -324,6 +328,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
# pylint: enable=protected-access
+@tf_export('self_adjoint_eig', 'linalg.eigh')
def self_adjoint_eig(tensor, name=None):
"""Computes the eigen decomposition of a batch of self-adjoint matrices.
@@ -346,6 +351,7 @@ def self_adjoint_eig(tensor, name=None):
return e, v
+@tf_export('self_adjoint_eigvals', 'linalg.eigvalsh')
def self_adjoint_eigvals(tensor, name=None):
"""Computes the eigenvalues of one or more self-adjoint matrices.
@@ -368,6 +374,7 @@ def self_adjoint_eigvals(tensor, name=None):
return e
+@tf_export('svd', 'linalg.svd')
def svd(tensor, full_matrices=False, compute_uv=True, name=None):
r"""Computes the singular value decompositions of one or more matrices.
@@ -439,6 +446,7 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None):
# pylint: disable=redefined-builtin
+@tf_export('norm', 'linalg.norm')
@deprecation.deprecated_args(
None, 'keep_dims is deprecated, use keepdims instead', 'keep_dims')
def norm(tensor,
diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py
index 6b31c00639..bba59ebcef 100644
--- a/tensorflow/python/ops/list_ops.py
+++ b/tensorflow/python/ops/list_ops.py
@@ -19,7 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
@@ -28,28 +30,30 @@ from tensorflow.python.ops.gen_list_ops import *
@ops.RegisterGradient("TensorListPushBack")
-def _PushBackGradient(op, dresult):
+def _PushBackGrad(op, dresult):
return gen_list_ops.tensor_list_pop_back(
dresult, element_dtype=op.get_attr("element_dtype"))
@ops.RegisterGradient("TensorListPopBack")
-def _PopBackGradient(unused_op, dlist, delement):
+def _PopBackGrad(op, dlist, delement):
if dlist is None:
dlist = gen_list_ops.empty_tensor_list(
element_dtype=delement.dtype,
- element_shape=-1)
+ element_shape=gen_list_ops.tensor_list_element_shape(
+ op.outputs[0], shape_type=dtypes.int32))
return gen_list_ops.tensor_list_push_back(dlist, delement)
@ops.RegisterGradient("TensorListStack")
-def _TensorListStack(unused_op, dtensor):
+def _TensorListStackGrad(unused_op, dtensor):
return gen_list_ops.tensor_list_from_tensor(dtensor,
element_shape=dtensor.shape[1:])
@ops.RegisterGradient("TensorListFromTensor")
-def _TensorListFromTensor(op, dlist):
+def _TensorListFromTensorGrad(op, dlist):
+ """Gradient for TensorListFromTensor."""
if op.inputs[0].shape[0] is not None:
num_elements = op.inputs[0].shape[0]
else:
@@ -57,7 +61,34 @@ def _TensorListFromTensor(op, dlist):
if dlist is None:
dlist = gen_list_ops.empty_tensor_list(
element_dtype=op.inputs[0].dtype,
- element_shape=-1)
+ element_shape=gen_list_ops.tensor_list_element_shape(
+ op.outputs[0], shape_type=dtypes.int32))
return gen_list_ops.tensor_list_stack(
dlist, element_dtype=op.inputs[0].dtype,
num_elements=num_elements)
+
+
+@ops.RegisterGradient("TensorListGetItem")
+def _TensorListGetItemGrad(op, ditem):
+ """Gradient for TensorListGetItem."""
+ list_size = gen_list_ops.tensor_list_length(op.inputs[0])
+ list_grad = gen_list_ops.tensor_list_set_item(
+ gen_list_ops.tensor_list_reserve(
+ gen_list_ops.tensor_list_element_shape(op.inputs[0],
+ shape_type=dtypes.int32),
+ list_size, element_dtype=ditem.dtype),
+ index=op.inputs[1],
+ item=ditem)
+ index_grad = None
+ return list_grad, index_grad
+
+
+@ops.RegisterGradient("TensorListSetItem")
+def _TensorListSetItemGrad(op, dlist):
+ _, index, item = op.inputs
+ list_grad = gen_list_ops.tensor_list_set_item(
+ dlist, index=index, item=array_ops.zeros_like(item))
+ index_grad = None
+ element_grad = gen_list_ops.tensor_list_get_item(
+ dlist, index, element_dtype=item.dtype)
+ return list_grad, index_grad, element_grad
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index 51ab2aec22..eadbc1b7c3 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
# The python wrapper for Assert is in control_flow_ops, as the Assert
# call relies on certain conditionals for its dependencies. Use
@@ -35,6 +36,7 @@ from tensorflow.python.util.deprecation import deprecated
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
+@tf_export("Print")
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 333e36873a..f539a7bb68 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -40,8 +40,10 @@ from tensorflow.python.ops.gen_lookup_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("initialize_all_tables")
@deprecated(None, "Use `tf.tables_initializer` instead.")
def initialize_all_tables(name="init_all_tables"):
"""Returns an Op that initializes all tables of the default graph.
@@ -56,6 +58,7 @@ def initialize_all_tables(name="init_all_tables"):
return tables_initializer(name)
+@tf_export("tables_initializer")
def tables_initializer(name="init_all_tables"):
"""Returns an Op that initializes all tables of the default graph.
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index cfdfa09757..b8e8207bb2 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -172,6 +172,7 @@ from tensorflow.python.ops.gen_math_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
# Aliases for some automatically-generated names.
linspace = gen_math_ops.lin_space
@@ -190,6 +191,7 @@ def _set_doc(doc):
# pylint: disable=redefined-builtin
+@tf_export("argmax")
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension")
@_set_doc(
@@ -209,6 +211,7 @@ def argmax(input,
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
+@tf_export("argmin")
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension")
@_set_doc(
@@ -233,6 +236,7 @@ def argmin(input,
# pylint: disable=anomalous-backslash-in-string,protected-access
# pylint: disable=g-docstring-has-escape
+@tf_export("abs")
def abs(x, name=None):
r"""Computes the absolute value of a tensor.
@@ -307,6 +311,7 @@ class DivideDelegateWithName(object):
return _div_python2(self.x, y, self.name)
+@tf_export("divide")
def divide(x, y, name=None):
"""Computes Python style division of `x` by `y`."""
@@ -318,6 +323,7 @@ def divide(x, y, name=None):
return x / y
+@tf_export("multiply")
def multiply(x, y, name=None):
return gen_math_ops._mul(x, y, name)
@@ -337,6 +343,7 @@ _mul.__doc__ = (
gen_math_ops._mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__))
+@tf_export("subtract")
def subtract(x, y, name=None):
return gen_math_ops._sub(x, y, name)
@@ -357,6 +364,7 @@ _sub.__doc__ = (
# pylint: disable=g-docstring-has-escape
+@tf_export("negative")
def negative(x, name=None):
"""Computes numerical negative value element-wise.
@@ -405,6 +413,7 @@ def _neg(x, name=None):
# pylint: enable=g-docstring-has-escape
+@tf_export("sign")
def sign(x, name=None):
"""Returns an element-wise indication of the sign of a number.
@@ -435,6 +444,7 @@ def sign(x, name=None):
return gen_math_ops.sign(x, name=name)
+@tf_export("square")
def square(x, name=None):
r"""Computes square of x element-wise.
@@ -457,6 +467,7 @@ def square(x, name=None):
return gen_math_ops.square(x, name=name)
+@tf_export("sqrt")
def sqrt(x, name=None):
r"""Computes square root of x element-wise.
@@ -479,6 +490,7 @@ def sqrt(x, name=None):
return gen_math_ops.sqrt(x, name=name)
+@tf_export("erf")
def erf(x, name=None):
"""Computes the Gauss error function of `x` element-wise.
@@ -499,6 +511,7 @@ def erf(x, name=None):
return gen_math_ops.erf(x, name=name)
+@tf_export("scalar_mul")
def scalar_mul(scalar, x):
"""Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
@@ -528,6 +541,7 @@ def scalar_mul(scalar, x):
raise ValueError("Only scalar multiply works, got shape %s" % shape)
+@tf_export("pow")
def pow(x, y, name=None):
r"""Computes the power of one value to another.
@@ -555,6 +569,7 @@ def pow(x, y, name=None):
# pylint: disable=redefined-builtin,redefined-outer-name
+@tf_export("complex")
def complex(real, imag, name=None):
r"""Converts two real numbers to a complex number.
@@ -596,6 +611,7 @@ def complex(real, imag, name=None):
return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
+@tf_export("real")
def real(input, name=None):
r"""Returns the real part of a complex (or real) tensor.
@@ -626,6 +642,7 @@ def real(input, name=None):
return input
+@tf_export("imag")
def imag(input, name=None):
r"""Returns the imaginary part of a complex (or real) tensor.
@@ -655,6 +672,7 @@ def imag(input, name=None):
return array_ops.zeros_like(input)
+@tf_export("angle")
def angle(input, name=None):
r"""Returns the element-wise argument of a complex (or real) tensor.
@@ -693,6 +711,7 @@ def angle(input, name=None):
# pylint: enable=redefined-outer-name,redefined-builtin
+@tf_export("round")
def round(x, name=None):
"""Rounds the values of a tensor to the nearest integer, element-wise.
@@ -719,6 +738,7 @@ def round(x, name=None):
return gen_math_ops.round(x, name=name)
+@tf_export("cast")
def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
@@ -759,6 +779,7 @@ def cast(x, dtype, name=None):
return gen_math_ops.cast(x, base_type, name=name)
+@tf_export("saturate_cast")
def saturate_cast(value, dtype, name=None):
"""Performs a safe saturating cast of `value` to `dtype`.
@@ -792,6 +813,7 @@ def saturate_cast(value, dtype, name=None):
return cast(value, dtype, name=name)
+@tf_export("to_float")
def to_float(x, name="ToFloat"):
"""Casts a tensor to type `float32`.
@@ -808,6 +830,7 @@ def to_float(x, name="ToFloat"):
return cast(x, dtypes.float32, name=name)
+@tf_export("to_double")
def to_double(x, name="ToDouble"):
"""Casts a tensor to type `float64`.
@@ -824,6 +847,7 @@ def to_double(x, name="ToDouble"):
return cast(x, dtypes.float64, name=name)
+@tf_export("to_int32")
def to_int32(x, name="ToInt32"):
"""Casts a tensor to type `int32`.
@@ -840,6 +864,7 @@ def to_int32(x, name="ToInt32"):
return cast(x, dtypes.int32, name=name)
+@tf_export("to_int64")
def to_int64(x, name="ToInt64"):
"""Casts a tensor to type `int64`.
@@ -856,6 +881,7 @@ def to_int64(x, name="ToInt64"):
return cast(x, dtypes.int64, name=name)
+@tf_export("to_bfloat16")
def to_bfloat16(x, name="ToBFloat16"):
"""Casts a tensor to type `bfloat16`.
@@ -1029,6 +1055,7 @@ def _div_python2(x, y, name=None):
return gen_math_ops._floor_div(x, y, name=name)
+@tf_export("truediv")
def truediv(x, y, name=None):
"""Divides x / y elementwise (using Python 3 division operator semantics).
@@ -1060,6 +1087,7 @@ def truediv(x, y, name=None):
return _truediv_python3(x, y, name)
+@tf_export("div")
def div(x, y, name=None):
"""Divides x / y elementwise (using Python 2 division operator semantics).
@@ -1087,6 +1115,7 @@ mod = gen_math_ops._floor_mod
# TODO(aselle): Deprecate this once all internal functionality uses
# tf.truncatediv
+@tf_export("floordiv")
def floordiv(x, y, name=None):
"""Divides `x / y` elementwise, rounding toward the most negative integer.
@@ -1157,6 +1186,7 @@ _OverrideBinaryOperatorHelper(gen_math_ops._floor_mod, "mod")
_OverrideBinaryOperatorHelper(pow, "pow")
+@tf_export("logical_xor")
def logical_xor(x, y, name="LogicalXor"):
"""x ^ y = (x | y) & ~(x & y)."""
# TODO(alemi) Make this a cwise op if people end up relying on it.
@@ -1176,6 +1206,7 @@ ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
+@tf_export("range")
def range(start, limit=None, delta=1, dtype=None, name="range"):
"""Creates a sequence of numbers.
@@ -1281,6 +1312,7 @@ def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
return output
+@tf_export("reduce_sum")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_sum(input_tensor,
@@ -1341,6 +1373,7 @@ def reduce_sum(input_tensor,
name=name))
+@tf_export("count_nonzero")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def count_nonzero(input_tensor,
@@ -1407,6 +1440,7 @@ def count_nonzero(input_tensor,
dtype=dtype)
+@tf_export("reduce_mean")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_mean(input_tensor,
@@ -1478,6 +1512,7 @@ def reduce_mean(input_tensor,
name=name))
+@tf_export("reduce_prod")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_prod(input_tensor,
@@ -1527,6 +1562,7 @@ def reduce_prod(input_tensor,
name=name))
+@tf_export("reduce_min")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_min(input_tensor,
@@ -1575,6 +1611,7 @@ def reduce_min(input_tensor,
name=name))
+@tf_export("reduce_max")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_max(input_tensor,
@@ -1623,6 +1660,7 @@ def reduce_max(input_tensor,
name=name))
+@tf_export("reduce_all")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_all(input_tensor,
@@ -1680,6 +1718,7 @@ def reduce_all(input_tensor,
name=name))
+@tf_export("reduce_any")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_any(input_tensor,
@@ -1737,6 +1776,7 @@ def reduce_any(input_tensor,
name=name))
+@tf_export("reduce_logsumexp")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_logsumexp(input_tensor,
@@ -1810,6 +1850,7 @@ def reduce_logsumexp(input_tensor,
return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
+@tf_export("trace", "linalg.trace")
def trace(x, name=None):
"""Compute the trace of a tensor `x`.
@@ -1851,6 +1892,7 @@ def trace(x, name=None):
return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
+@tf_export("matmul")
def matmul(a,
b,
transpose_a=False,
@@ -2103,6 +2145,7 @@ def _as_indexed_slices_list(inputs, optimize=True):
return casted_outputs
+@tf_export("add_n")
def add_n(inputs, name=None):
"""Adds all input tensors element-wise.
@@ -2132,6 +2175,7 @@ def add_n(inputs, name=None):
return gen_math_ops._add_n(inputs, name=name)
+@tf_export("accumulate_n")
def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
"""Returns the element-wise sum of a list of tensors.
@@ -2216,6 +2260,7 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
ref, var_name=var.op.name, name=name)
+@tf_export("nn.sigmoid", "sigmoid")
def sigmoid(x, name=None):
"""Computes sigmoid of `x` element-wise.
@@ -2238,6 +2283,7 @@ def sigmoid(x, name=None):
return gen_math_ops._sigmoid(x, name=name)
+@tf_export("log_sigmoid")
def log_sigmoid(x, name=None):
"""Computes log sigmoid of `x` element-wise.
@@ -2256,6 +2302,7 @@ def log_sigmoid(x, name=None):
return gen_math_ops._neg(gen_nn_ops.softplus(-x), name=name)
+@tf_export("nn.tanh", "tanh")
def tanh(x, name=None):
"""Computes hyperbolic tangent of `x` element-wise.
@@ -2276,6 +2323,7 @@ def tanh(x, name=None):
return gen_math_ops._tanh(x, name=name)
+@tf_export("bincount")
def bincount(arr,
weights=None,
minlength=None,
@@ -2322,6 +2370,7 @@ def bincount(arr,
return gen_math_ops.bincount(arr, output_size, weights)
+@tf_export("cumsum")
def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative sum of the tensor `x` along `axis`.
@@ -2373,6 +2422,7 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
x, axis, exclusive=exclusive, reverse=reverse, name=name)
+@tf_export("cumprod")
def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative product of the tensor `x` along `axis`.
@@ -2424,6 +2474,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
x, axis, exclusive=exclusive, reverse=reverse, name=name)
+@tf_export("conj")
def conj(x, name=None):
r"""Returns the complex conjugate of a complex number.
@@ -2502,6 +2553,7 @@ def reduced_shape(input_shape, axes):
]) # [1, 1]
+@tf_export("sparse_segment_sum")
def sparse_segment_sum(data, indices, segment_ids, name=None,
num_segments=None):
r"""Computes the sum along sparse segments of a tensor.
@@ -2576,6 +2628,7 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
name=name)
+@tf_export("sparse_segment_mean")
def sparse_segment_mean(data, indices, segment_ids, name=None,
num_segments=None):
r"""Computes the mean along sparse segments of a tensor.
@@ -2619,6 +2672,7 @@ def sparse_segment_mean(data, indices, segment_ids, name=None,
name=name)
+@tf_export("sparse_segment_sqrt_n")
def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
num_segments=None):
r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
@@ -2655,6 +2709,7 @@ def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
name=name)
+@tf_export("tensordot", "linalg.tensordot")
def tensordot(a, b, axes, name=None):
r"""Tensor contraction of a and b along specified axes.
@@ -2772,10 +2827,14 @@ def tensordot(a, b, axes, name=None):
"""Generates two sets of contraction axes for the two tensor arguments."""
a_shape = a.get_shape()
if isinstance(axes, compat.integral_types):
- if axes < 1:
- raise ValueError("'axes' must be at least 1.")
+ if axes < 0:
+ raise ValueError("'axes' must be at least 0.")
if a_shape.ndims is not None:
- return range(a_shape.ndims - axes, a_shape.ndims), range(axes)
+ if axes > a_shape.ndims:
+ raise ValueError("'axes' must not be larger than the number of "
+ "dimensions of tensor %s." % a)
+ return (list(xrange(a_shape.ndims - axes, a_shape.ndims)),
+ list(xrange(axes)))
else:
rank = array_ops.rank(a)
return (range(rank - axes, rank, dtype=dtypes.int32),
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 25e1613a65..7776ff08c4 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
def metric_variable(shape, dtype, validate_shape=True, name=None):
@@ -99,27 +100,29 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
# Use dynamic rank.
weights_rank_tensor = array_ops.rank(weights)
rank_diff = weights_rank_tensor - array_ops.rank(predictions)
+
def _maybe_expand_weights():
return control_flow_ops.cond(
math_ops.equal(rank_diff, -1),
- lambda: array_ops.expand_dims(weights, [-1]),
- lambda: weights)
+ lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)
+
# Don't attempt squeeze if it will fail based on static check.
if ((weights_rank is not None) and
(not weights_shape.dims[-1].is_compatible_with(1))):
maybe_squeeze_weights = lambda: weights
else:
maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])
+
def _maybe_adjust_weights():
return control_flow_ops.cond(
- math_ops.equal(rank_diff, 1),
- maybe_squeeze_weights,
+ math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
_maybe_expand_weights)
+
# If weights are scalar, do nothing. Otherwise, try to add or remove a
# dimension to match predictions.
weights = control_flow_ops.cond(
- math_ops.equal(weights_rank_tensor, 0),
- lambda: weights, _maybe_adjust_weights)
+ math_ops.equal(weights_rank_tensor, 0), lambda: weights,
+ _maybe_adjust_weights)
return predictions, labels, weights
@@ -164,14 +167,14 @@ def _maybe_expand_labels(labels, predictions):
if predictions_rank == labels_rank + 1:
return array_ops.expand_dims(labels, -1, name=scope)
raise ValueError(
- 'Unexpected labels shape %s for predictions shape %s.' % (
- labels.get_shape(), predictions.get_shape()))
+ 'Unexpected labels shape %s for predictions shape %s.' %
+ (labels.get_shape(), predictions.get_shape()))
# Otherwise, use dynamic shape.
return control_flow_ops.cond(
- math_ops.equal(array_ops.rank(predictions), array_ops.rank(labels) + 1),
- lambda: array_ops.expand_dims(labels, -1, name=scope),
- lambda: labels)
+ math_ops.equal(array_ops.rank(predictions),
+ array_ops.rank(labels) + 1),
+ lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)
def _safe_div(numerator, denominator, name):
@@ -262,8 +265,12 @@ def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
return total_cm, update_op
-def mean(values, weights=None, metrics_collections=None,
- updates_collections=None, name=None):
+@tf_export('metrics.mean')
+def mean(values,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes the (weighted) mean of the given values.
The `mean` function creates two local variables, `total` and `count`
@@ -337,8 +344,13 @@ def mean(values, weights=None, metrics_collections=None,
return mean_t, update_op
-def accuracy(labels, predictions, weights=None, metrics_collections=None,
- updates_collections=None, name=None):
+@tf_export('metrics.accuracy')
+def accuracy(labels,
+ predictions,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Calculates how often `predictions` matches `labels`.
The `accuracy` function creates two local variables, `total` and
@@ -392,12 +404,15 @@ def accuracy(labels, predictions, weights=None, metrics_collections=None,
if labels.dtype != predictions.dtype:
predictions = math_ops.cast(predictions, labels.dtype)
is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
- return mean(is_correct, weights, metrics_collections,
- updates_collections, name or 'accuracy')
+ return mean(is_correct, weights, metrics_collections, updates_collections,
+ name or 'accuracy')
-def _confusion_matrix_at_thresholds(
- labels, predictions, thresholds, weights=None, includes=None):
+def _confusion_matrix_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
+ includes=None):
"""Computes true_positives, false_negatives, true_negatives, false_positives.
This function creates up to four local variables, `true_positives`,
@@ -495,8 +510,8 @@ def _confusion_matrix_at_thresholds(
if weights is not None:
weights = weights_broadcast_ops.broadcast_weights(
math_ops.to_float(weights), predictions)
- weights_tiled = array_ops.tile(array_ops.reshape(
- weights, [1, -1]), [num_thresholds, 1])
+ weights_tiled = array_ops.tile(
+ array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
thresh_tiled.get_shape().assert_is_compatible_with(
weights_tiled.get_shape())
else:
@@ -512,8 +527,9 @@ def _confusion_matrix_at_thresholds(
math_ops.logical_and(label_is_pos, pred_is_pos))
if weights_tiled is not None:
is_true_positive *= weights_tiled
- update_ops['tp'] = state_ops.assign_add(
- true_p, math_ops.reduce_sum(is_true_positive, 1))
+ update_ops['tp'] = state_ops.assign_add(true_p,
+ math_ops.reduce_sum(
+ is_true_positive, 1))
values['tp'] = true_p
if 'fn' in includes:
@@ -523,8 +539,9 @@ def _confusion_matrix_at_thresholds(
math_ops.logical_and(label_is_pos, pred_is_neg))
if weights_tiled is not None:
is_false_negative *= weights_tiled
- update_ops['fn'] = state_ops.assign_add(
- false_n, math_ops.reduce_sum(is_false_negative, 1))
+ update_ops['fn'] = state_ops.assign_add(false_n,
+ math_ops.reduce_sum(
+ is_false_negative, 1))
values['fn'] = false_n
if 'tn' in includes:
@@ -534,8 +551,9 @@ def _confusion_matrix_at_thresholds(
math_ops.logical_and(label_is_neg, pred_is_neg))
if weights_tiled is not None:
is_true_negative *= weights_tiled
- update_ops['tn'] = state_ops.assign_add(
- true_n, math_ops.reduce_sum(is_true_negative, 1))
+ update_ops['tn'] = state_ops.assign_add(true_n,
+ math_ops.reduce_sum(
+ is_true_negative, 1))
values['tn'] = true_n
if 'fp' in includes:
@@ -545,16 +563,24 @@ def _confusion_matrix_at_thresholds(
math_ops.logical_and(label_is_neg, pred_is_pos))
if weights_tiled is not None:
is_false_positive *= weights_tiled
- update_ops['fp'] = state_ops.assign_add(
- false_p, math_ops.reduce_sum(is_false_positive, 1))
+ update_ops['fp'] = state_ops.assign_add(false_p,
+ math_ops.reduce_sum(
+ is_false_positive, 1))
values['fp'] = false_p
return values, update_ops
-def auc(labels, predictions, weights=None, num_thresholds=200,
- metrics_collections=None, updates_collections=None,
- curve='ROC', name=None, summation_method='trapezoidal'):
+@tf_export('metrics.auc')
+def auc(labels,
+ predictions,
+ weights=None,
+ num_thresholds=200,
+ metrics_collections=None,
+ updates_collections=None,
+ curve='ROC',
+ name=None,
+ summation_method='trapezoidal'):
"""Computes the approximate AUC via a Riemann sum.
The `auc` function creates four local variables, `true_positives`,
@@ -622,14 +648,14 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
raise RuntimeError('tf.metrics.auc is not supported when eager execution '
'is enabled.')
- with variable_scope.variable_scope(
- name, 'auc', (labels, predictions, weights)):
+ with variable_scope.variable_scope(name, 'auc',
+ (labels, predictions, weights)):
if curve != 'ROC' and curve != 'PR':
- raise ValueError('curve must be either ROC or PR, %s unknown' %
- (curve))
+ raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
kepsilon = 1e-7 # to account for floating point imprecisions
- thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- for i in range(num_thresholds-2)]
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _confusion_matrix_at_thresholds(
@@ -637,6 +663,7 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
# Add epsilons to avoid dividing by 0.
epsilon = 1.0e-6
+
def compute_auc(tp, fn, tn, fp, name):
"""Computes the roc-auc or pr-auc based on confusion counts."""
rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
@@ -667,11 +694,10 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
raise ValueError('Invalid summation_method: %s' % summation_method)
# sum up the areas of all the trapeziums
- auc_value = compute_auc(
- values['tp'], values['fn'], values['tn'], values['fp'], 'value')
- update_op = compute_auc(
- update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'],
- 'update_op')
+ auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
+ values['fp'], 'value')
+ update_op = compute_auc(update_ops['tp'], update_ops['fn'],
+ update_ops['tn'], update_ops['fp'], 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, auc_value)
@@ -682,7 +708,10 @@ def auc(labels, predictions, weights=None, num_thresholds=200,
return auc_value, update_op
-def mean_absolute_error(labels, predictions, weights=None,
+@tf_export('metrics.mean_absolute_error')
+def mean_absolute_error(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -740,7 +769,11 @@ def mean_absolute_error(labels, predictions, weights=None,
updates_collections, name or 'mean_absolute_error')
-def mean_cosine_distance(labels, predictions, dim, weights=None,
+@tf_export('metrics.mean_cosine_distance')
+def mean_cosine_distance(labels,
+ predictions,
+ dim,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -796,10 +829,8 @@ def mean_cosine_distance(labels, predictions, dim, weights=None,
radial_diffs, reduction_indices=[
dim,
], keepdims=True)
- mean_distance, update_op = mean(radial_diffs, weights,
- None,
- None,
- name or 'mean_cosine_distance')
+ mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
+ 'mean_cosine_distance')
mean_distance = math_ops.subtract(1.0, mean_distance)
update_op = math_ops.subtract(1.0, update_op)
@@ -812,6 +843,7 @@ def mean_cosine_distance(labels, predictions, dim, weights=None,
return mean_distance, update_op
+@tf_export('metrics.mean_per_class_accuracy')
def mean_per_class_accuracy(labels,
predictions,
num_classes,
@@ -824,8 +856,8 @@ def mean_per_class_accuracy(labels,
Calculates the accuracy for each class, then takes the mean of that.
For estimation of the metric over a stream of data, the function creates an
- `update_op` operation that updates these variables and returns the
- `mean_accuracy`.
+ `update_op` operation that updates the accuracy of each class and returns
+ them.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
@@ -836,8 +868,8 @@ def mean_per_class_accuracy(labels,
shape is [batch size] and type `int32` or `int64`. The tensor will be
flattened if its rank > 1.
num_classes: The possible number of labels the prediction task can
- have. This value must be provided, since a confusion matrix of
- dimension = [num_classes, num_classes] will be allocated.
+ have. This value must be provided, since two variables with shape =
+ [num_classes] will be allocated.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `labels` dimension).
@@ -850,7 +882,7 @@ def mean_per_class_accuracy(labels,
Returns:
mean_accuracy: A `Tensor` representing the mean per class accuracy.
- update_op: An operation that increments the confusion matrix.
+ update_op: An operation that updates the accuracy tensor.
Raises:
ValueError: If `predictions` and `labels` have mismatched shapes, or if
@@ -865,27 +897,43 @@ def mean_per_class_accuracy(labels,
with variable_scope.variable_scope(name, 'mean_accuracy',
(predictions, labels, weights)):
+ labels = math_ops.to_int64(labels)
+
+ # Flatten the input if its rank > 1.
+ if labels.get_shape().ndims > 1:
+ labels = array_ops.reshape(labels, [-1])
+
+ if predictions.get_shape().ndims > 1:
+ predictions = array_ops.reshape(predictions, [-1])
+
# Check if shape is compatible.
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
- total_cm, update_op = _streaming_confusion_matrix(
- labels, predictions, num_classes, weights=weights)
+ total = metric_variable([num_classes], dtypes.float32, name='total')
+ count = metric_variable([num_classes], dtypes.float32, name='count')
- def compute_mean_accuracy(name):
- """Compute the mean per class accuracy via the confusion matrix."""
- per_row_sum = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
- cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
- denominator = per_row_sum
+ ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)
- # If the value of the denominator is 0, set it to 1 to avoid
- # zero division.
- denominator = array_ops.where(
- math_ops.greater(denominator, 0), denominator,
- array_ops.ones_like(denominator))
- accuracies = math_ops.div(cm_diag, denominator)
- return math_ops.reduce_mean(accuracies, name=name)
+ if labels.dtype != predictions.dtype:
+ predictions = math_ops.cast(predictions, labels.dtype)
+ is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
+
+ if weights is not None:
+ if weights.get_shape().ndims > 1:
+ weights = array_ops.reshape(weights, [-1])
+ weights = math_ops.to_float(weights)
- mean_accuracy_v = compute_mean_accuracy('mean_accuracy')
+ is_correct = is_correct * weights
+ ones = ones * weights
+
+ update_total_op = state_ops.scatter_add(total, labels, ones)
+ update_count_op = state_ops.scatter_add(count, labels, is_correct)
+
+ per_class_accuracy = _safe_div(count, total, None)
+
+ mean_accuracy_v = math_ops.reduce_mean(
+ per_class_accuracy, name='mean_accuracy')
+ update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, mean_accuracy_v)
@@ -896,6 +944,7 @@ def mean_per_class_accuracy(labels,
return mean_accuracy_v, update_op
+@tf_export('metrics.mean_iou')
def mean_iou(labels,
predictions,
num_classes,
@@ -951,13 +1000,14 @@ def mean_iou(labels,
raise RuntimeError('tf.metrics.mean_iou is not supported when '
'eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'mean_iou', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'mean_iou',
+ (predictions, labels, weights)):
# Check if shape is compatible.
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
num_classes, weights)
+
def compute_mean_iou(name):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
@@ -968,22 +1018,21 @@ def mean_iou(labels,
# The mean is only computed over classes that appear in the
# label or prediction tensor. If the denominator is 0, we need to
# ignore the class.
- num_valid_entries = math_ops.reduce_sum(math_ops.cast(
- math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
+ num_valid_entries = math_ops.reduce_sum(
+ math_ops.cast(
+ math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
# If the value of the denominator is 0, set it to 1 to avoid
# zero division.
denominator = array_ops.where(
- math_ops.greater(denominator, 0),
- denominator,
+ math_ops.greater(denominator, 0), denominator,
array_ops.ones_like(denominator))
iou = math_ops.div(cm_diag, denominator)
# If the number of valid entries is 0 (no classes) we return 0.
result = array_ops.where(
math_ops.greater(num_valid_entries, 0),
- math_ops.reduce_sum(iou, name=name) / num_valid_entries,
- 0)
+ math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
return result
mean_iou_v = compute_mean_iou('mean_iou')
@@ -997,7 +1046,11 @@ def mean_iou(labels,
return mean_iou_v, update_op
-def mean_relative_error(labels, predictions, normalizer, weights=None,
+@tf_export('metrics.mean_relative_error')
+def mean_relative_error(labels,
+ predictions,
+ normalizer,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1056,14 +1109,16 @@ def mean_relative_error(labels, predictions, normalizer, weights=None,
predictions, normalizer)
predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
relative_errors = array_ops.where(
- math_ops.equal(normalizer, 0.0),
- array_ops.zeros_like(labels),
+ math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
math_ops.div(math_ops.abs(labels - predictions), normalizer))
return mean(relative_errors, weights, metrics_collections,
updates_collections, name or 'mean_relative_error')
-def mean_squared_error(labels, predictions, weights=None,
+@tf_export('metrics.mean_squared_error')
+def mean_squared_error(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1117,12 +1172,16 @@ def mean_squared_error(labels, predictions, weights=None,
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=predictions, labels=labels, weights=weights)
squared_error = math_ops.square(labels - predictions)
- return mean(squared_error, weights, metrics_collections,
- updates_collections, name or 'mean_squared_error')
+ return mean(squared_error, weights, metrics_collections, updates_collections,
+ name or 'mean_squared_error')
-def mean_tensor(values, weights=None, metrics_collections=None,
- updates_collections=None, name=None):
+@tf_export('metrics.mean_tensor')
+def mean_tensor(values,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes the element-wise (weighted) mean of the given tensors.
In contrast to the `mean` function which returns a scalar with the
@@ -1189,9 +1248,8 @@ def mean_tensor(values, weights=None, metrics_collections=None,
update_count_op = state_ops.assign_add(count, num_values)
def compute_mean(total, count, name):
- non_zero_count = math_ops.maximum(count,
- array_ops.ones_like(count),
- name=name)
+ non_zero_count = math_ops.maximum(
+ count, array_ops.ones_like(count), name=name)
return math_ops.truediv(total, non_zero_count, name=name)
mean_t = compute_mean(total, count, 'value')
@@ -1206,7 +1264,10 @@ def mean_tensor(values, weights=None, metrics_collections=None,
return mean_t, update_op
-def percentage_below(values, threshold, weights=None,
+@tf_export('metrics.percentage_below')
+def percentage_below(values,
+ threshold,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1253,14 +1314,13 @@ def percentage_below(values, threshold, weights=None,
'eager execution is enabled.')
is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
- return mean(is_below_threshold,
- weights,
- metrics_collections,
- updates_collections,
- name or 'percentage_below_threshold')
+ return mean(is_below_threshold, weights, metrics_collections,
+ updates_collections, name or 'percentage_below_threshold')
-def _count_condition(values, weights=None, metrics_collections=None,
+def _count_condition(values,
+ weights=None,
+ metrics_collections=None,
updates_collections=None):
"""Sums the weights of cases where the given values are True.
@@ -1290,8 +1350,8 @@ def _count_condition(values, weights=None, metrics_collections=None,
values = math_ops.to_float(values)
if weights is not None:
- with ops.control_dependencies((
- check_ops.assert_rank_in(weights, (0, array_ops.rank(values))),)):
+ with ops.control_dependencies((check_ops.assert_rank_in(
+ weights, (0, array_ops.rank(values))),)):
weights = math_ops.to_float(weights)
values = math_ops.multiply(values, weights)
@@ -1307,7 +1367,10 @@ def _count_condition(values, weights=None, metrics_collections=None,
return value_tensor, update_op
-def false_negatives(labels, predictions, weights=None,
+@tf_export('metrics.false_negatives')
+def false_negatives(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1343,20 +1406,24 @@ def false_negatives(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.false_negatives is not supported when '
'eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'false_negatives', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'false_negatives',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
- is_false_negative = math_ops.logical_and(math_ops.equal(labels, True),
- math_ops.equal(predictions, False))
+ is_false_negative = math_ops.logical_and(
+ math_ops.equal(labels, True), math_ops.equal(predictions, False))
return _count_condition(is_false_negative, weights, metrics_collections,
updates_collections)
-def false_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
+@tf_export('metrics.false_negatives_at_thresholds')
+def false_negatives_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1409,7 +1476,10 @@ def false_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
return values['fn'], update_ops['fn']
-def false_positives(labels, predictions, weights=None,
+@tf_export('metrics.false_positives')
+def false_positives(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1446,20 +1516,24 @@ def false_positives(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.false_positives is not supported when '
'eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'false_positives', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'false_positives',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
- is_false_positive = math_ops.logical_and(math_ops.equal(labels, False),
- math_ops.equal(predictions, True))
+ is_false_positive = math_ops.logical_and(
+ math_ops.equal(labels, False), math_ops.equal(predictions, True))
return _count_condition(is_false_positive, weights, metrics_collections,
updates_collections)
-def false_positives_at_thresholds(labels, predictions, thresholds, weights=None,
+@tf_export('metrics.false_positives_at_thresholds')
+def false_positives_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1512,7 +1586,10 @@ def false_positives_at_thresholds(labels, predictions, thresholds, weights=None,
return values['fp'], update_ops['fp']
-def true_negatives(labels, predictions, weights=None,
+@tf_export('metrics.true_negatives')
+def true_negatives(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1549,20 +1626,24 @@ def true_negatives(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.true_negatives is not '
'supported when eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'true_negatives', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'true_negatives',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
- is_true_negative = math_ops.logical_and(math_ops.equal(labels, False),
- math_ops.equal(predictions, False))
+ is_true_negative = math_ops.logical_and(
+ math_ops.equal(labels, False), math_ops.equal(predictions, False))
return _count_condition(is_true_negative, weights, metrics_collections,
updates_collections)
-def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
+@tf_export('metrics.true_negatives_at_thresholds')
+def true_negatives_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1615,7 +1696,10 @@ def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None,
return values['tn'], update_ops['tn']
-def true_positives(labels, predictions, weights=None,
+@tf_export('metrics.true_positives')
+def true_positives(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1652,20 +1736,24 @@ def true_positives(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.true_positives is not '
'supported when eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'true_positives', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'true_positives',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
- is_true_positive = math_ops.logical_and(math_ops.equal(labels, True),
- math_ops.equal(predictions, True))
+ is_true_positive = math_ops.logical_and(
+ math_ops.equal(labels, True), math_ops.equal(predictions, True))
return _count_condition(is_true_positive, weights, metrics_collections,
updates_collections)
-def true_positives_at_thresholds(labels, predictions, thresholds, weights=None,
+@tf_export('metrics.true_positives_at_thresholds')
+def true_positives_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -1718,8 +1806,12 @@ def true_positives_at_thresholds(labels, predictions, thresholds, weights=None,
return values['tp'], update_ops['tp']
-def precision(labels, predictions, weights=None,
- metrics_collections=None, updates_collections=None,
+@tf_export('metrics.precision')
+def precision(labels,
+ predictions,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
name=None):
"""Computes the precision of the predictions with respect to the labels.
@@ -1768,8 +1860,8 @@ def precision(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.precision is not '
'supported when eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'precision', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'precision',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
@@ -1777,22 +1869,27 @@ def precision(labels, predictions, weights=None,
weights=weights)
true_p, true_positives_update_op = true_positives(
- labels, predictions, weights, metrics_collections=None,
- updates_collections=None, name=None)
+ labels,
+ predictions,
+ weights,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None)
false_p, false_positives_update_op = false_positives(
- labels, predictions, weights, metrics_collections=None,
- updates_collections=None, name=None)
+ labels,
+ predictions,
+ weights,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None)
def compute_precision(tp, fp, name):
return array_ops.where(
- math_ops.greater(tp + fp, 0),
- math_ops.div(tp, tp + fp),
- 0,
- name)
+ math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
p = compute_precision(true_p, false_p, 'value')
- update_op = compute_precision(
- true_positives_update_op, false_positives_update_op, 'update_op')
+ update_op = compute_precision(true_positives_update_op,
+ false_positives_update_op, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, p)
@@ -1803,10 +1900,14 @@ def precision(labels, predictions, weights=None,
return p, update_op
-def precision_at_thresholds(labels, predictions, thresholds,
+@tf_export('metrics.precision_at_thresholds')
+def precision_at_thresholds(labels,
+ predictions,
+ thresholds,
weights=None,
metrics_collections=None,
- updates_collections=None, name=None):
+ updates_collections=None,
+ name=None):
"""Computes precision values for different `thresholds` on `predictions`.
The `precision_at_thresholds` function creates four local variables,
@@ -1862,12 +1963,13 @@ def precision_at_thresholds(labels, predictions, thresholds,
# Avoid division by zero.
epsilon = 1e-7
+
def compute_precision(tp, fp, name):
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
prec = compute_precision(values['tp'], values['fp'], 'value')
- update_op = compute_precision(
- update_ops['tp'], update_ops['fp'], 'update_op')
+ update_op = compute_precision(update_ops['tp'], update_ops['fp'],
+ 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, prec)
@@ -1878,8 +1980,12 @@ def precision_at_thresholds(labels, predictions, thresholds,
return prec, update_op
-def recall(labels, predictions, weights=None,
- metrics_collections=None, updates_collections=None,
+@tf_export('metrics.recall')
+def recall(labels,
+ predictions,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
name=None):
"""Computes the recall of the predictions with respect to the labels.
@@ -1926,30 +2032,36 @@ def recall(labels, predictions, weights=None,
raise RuntimeError('tf.metrics.recall is not supported is not '
'supported when eager execution is enabled.')
- with variable_scope.variable_scope(
- name, 'recall', (predictions, labels, weights)):
+ with variable_scope.variable_scope(name, 'recall',
+ (predictions, labels, weights)):
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=math_ops.cast(predictions, dtype=dtypes.bool),
labels=math_ops.cast(labels, dtype=dtypes.bool),
weights=weights)
true_p, true_positives_update_op = true_positives(
- labels, predictions, weights, metrics_collections=None,
- updates_collections=None, name=None)
+ labels,
+ predictions,
+ weights,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None)
false_n, false_negatives_update_op = false_negatives(
- labels, predictions, weights, metrics_collections=None,
- updates_collections=None, name=None)
+ labels,
+ predictions,
+ weights,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None)
def compute_recall(true_p, false_n, name):
return array_ops.where(
math_ops.greater(true_p + false_n, 0),
- math_ops.div(true_p, true_p + false_n),
- 0,
- name)
+ math_ops.div(true_p, true_p + false_n), 0, name)
rec = compute_recall(true_p, false_n, 'value')
- update_op = compute_recall(
- true_positives_update_op, false_negatives_update_op, 'update_op')
+ update_op = compute_recall(true_positives_update_op,
+ false_negatives_update_op, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, rec)
@@ -1983,8 +2095,8 @@ def _select_class_id(ids, selected_id):
"""
ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
if isinstance(ids, sparse_tensor.SparseTensor):
- return sparse_ops.sparse_retain(
- ids, math_ops.equal(ids.values, selected_id))
+ return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
+ selected_id))
# TODO(ptucker): Make this more efficient, maybe add a sparse version of
# tf.equal and tf.reduce_any?
@@ -1992,12 +2104,13 @@ def _select_class_id(ids, selected_id):
# Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
ids_last_dim = array_ops.size(ids_shape) - 1
- filled_selected_id_shape = math_ops.reduced_shape(
- ids_shape, array_ops.reshape(ids_last_dim, [1]))
+ filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
+ array_ops.reshape(
+ ids_last_dim, [1]))
# Intersect `ids` with the selected ID.
- filled_selected_id = array_ops.fill(
- filled_selected_id_shape, math_ops.to_int64(selected_id))
+ filled_selected_id = array_ops.fill(filled_selected_id_shape,
+ math_ops.to_int64(selected_id))
result = sets.set_intersection(filled_selected_id, ids)
return sparse_tensor.SparseTensor(
indices=result.indices, values=result.values, dense_shape=ids_shape)
@@ -2057,15 +2170,15 @@ def _sparse_true_positive_at_k(labels,
Returns:
A [D1, ... DN] `Tensor` of true positive counts.
"""
- with ops.name_scope(
- name, 'true_positives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(
- labels, predictions_idx, class_id)
+ with ops.name_scope(name, 'true_positives',
+ (predictions_idx, labels, weights)):
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
+ class_id)
tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
tp = math_ops.to_double(tp)
if weights is not None:
- with ops.control_dependencies((
- weights_broadcast_ops.assert_broadcastable(weights, tp),)):
+ with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+ weights, tp),)):
weights = math_ops.to_double(weights)
tp = math_ops.multiply(tp, weights)
return tp
@@ -2109,11 +2222,12 @@ def _streaming_sparse_true_positive_at_k(labels,
Raises:
ValueError: If `weights` is not `None` and has an incompatible shape.
"""
- with ops.name_scope(
- name, _at_k_name('true_positive', k, class_id=class_id),
- (predictions_idx, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
tp = _sparse_true_positive_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+ predictions_idx=predictions_idx,
+ labels=labels,
+ class_id=class_id,
weights=weights)
batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
@@ -2150,18 +2264,16 @@ def _sparse_false_negative_at_k(labels,
Returns:
A [D1, ... DN] `Tensor` of false negative counts.
"""
- with ops.name_scope(
- None, 'false_negatives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(labels,
- predictions_idx,
+ with ops.name_scope(None, 'false_negatives',
+ (predictions_idx, labels, weights)):
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
class_id)
- fn = sets.set_size(sets.set_difference(predictions_idx,
- labels,
- aminusb=False))
+ fn = sets.set_size(
+ sets.set_difference(predictions_idx, labels, aminusb=False))
fn = math_ops.to_double(fn)
if weights is not None:
- with ops.control_dependencies((
- weights_broadcast_ops.assert_broadcastable(weights, fn),)):
+ with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+ weights, fn),)):
weights = math_ops.to_double(weights)
fn = math_ops.multiply(fn, weights)
return fn
@@ -2205,11 +2317,12 @@ def _streaming_sparse_false_negative_at_k(labels,
Raises:
ValueError: If `weights` is not `None` and has an incompatible shape.
"""
- with ops.name_scope(
- name, _at_k_name('false_negative', k, class_id=class_id),
- (predictions_idx, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
fn = _sparse_false_negative_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+ predictions_idx=predictions_idx,
+ labels=labels,
+ class_id=class_id,
weights=weights)
batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))
@@ -2217,6 +2330,7 @@ def _streaming_sparse_false_negative_at_k(labels,
return var, state_ops.assign_add(var, batch_total_fn, name='update')
+@tf_export('metrics.recall_at_k')
def recall_at_k(labels,
predictions,
k,
@@ -2295,9 +2409,8 @@ def recall_at_k(labels,
raise RuntimeError('tf.metrics.recall_at_k is not '
'supported when eager execution is enabled.')
- with ops.name_scope(
- name, _at_k_name('recall', k, class_id=class_id),
- (predictions, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
+ (predictions, labels, weights)) as scope:
_, top_k_idx = nn.top_k(predictions, k)
return recall_at_top_k(
labels=labels,
@@ -2310,6 +2423,7 @@ def recall_at_k(labels,
name=scope)
+@tf_export('metrics.recall_at_top_k')
def recall_at_top_k(labels,
predictions_idx,
k=None,
@@ -2363,16 +2477,21 @@ def recall_at_top_k(labels,
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
- with ops.name_scope(name,
- _at_k_name('recall', k, class_id=class_id),
+ with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
(predictions_idx, labels, weights)) as scope:
labels = _maybe_expand_labels(labels, predictions_idx)
top_k_idx = math_ops.to_int64(predictions_idx)
tp, tp_update = _streaming_sparse_true_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ predictions_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
weights=weights)
fn, fn_update = _streaming_sparse_false_negative_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ predictions_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
weights=weights)
metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
@@ -2385,9 +2504,14 @@ def recall_at_top_k(labels,
return metric, update
-def recall_at_thresholds(labels, predictions, thresholds,
- weights=None, metrics_collections=None,
- updates_collections=None, name=None):
+@tf_export('metrics.recall_at_thresholds')
+def recall_at_thresholds(labels,
+ predictions,
+ thresholds,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes various recall values for different `thresholds` on `predictions`.
The `recall_at_thresholds` function creates four local variables,
@@ -2441,6 +2565,7 @@ def recall_at_thresholds(labels, predictions, thresholds,
# Avoid division by zero.
epsilon = 1e-7
+
def compute_recall(tp, fn, name):
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
@@ -2456,7 +2581,10 @@ def recall_at_thresholds(labels, predictions, thresholds,
return rec, update_op
-def root_mean_squared_error(labels, predictions, weights=None,
+@tf_export('metrics.root_mean_squared_error')
+def root_mean_squared_error(labels,
+ predictions,
+ weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
@@ -2509,9 +2637,9 @@ def root_mean_squared_error(labels, predictions, weights=None,
predictions, labels, weights = _remove_squeezable_dimensions(
predictions=predictions, labels=labels, weights=weights)
- mse, update_mse_op = mean_squared_error(
- labels, predictions, weights, None, None,
- name or 'root_mean_squared_error')
+ mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
+ None, name or
+ 'root_mean_squared_error')
rmse = math_ops.sqrt(mse)
update_rmse_op = math_ops.sqrt(update_mse_op)
@@ -2525,9 +2653,15 @@ def root_mean_squared_error(labels, predictions, weights=None,
return rmse, update_rmse_op
-def sensitivity_at_specificity(
- labels, predictions, specificity, weights=None, num_thresholds=200,
- metrics_collections=None, updates_collections=None, name=None):
+@tf_export('metrics.sensitivity_at_specificity')
+def sensitivity_at_specificity(labels,
+ predictions,
+ specificity,
+ weights=None,
+ num_thresholds=200,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes the specificity at a given sensitivity.
The `sensitivity_at_specificity` function creates four local
@@ -2588,8 +2722,9 @@ def sensitivity_at_specificity(
with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
(predictions, labels, weights)):
kepsilon = 1e-7 # to account for floating point imprecisions
- thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- for i in range(num_thresholds-2)]
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _confusion_matrix_at_thresholds(
@@ -2601,8 +2736,7 @@ def sensitivity_at_specificity(
tf_index = math_ops.cast(tf_index, dtypes.int32)
# Now, we have the implicit threshold, so compute the sensitivity:
- return math_ops.div(tp[tf_index],
- tp[tf_index] + fn[tf_index] + kepsilon,
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
name)
sensitivity = compute_sensitivity_at_specificity(
@@ -2641,8 +2775,8 @@ def _expand_and_tile(tensor, multiple, dim=0, name=None):
"""
if multiple < 1:
raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
- with ops.name_scope(
- name, 'expand_and_tile', (tensor, multiple, dim)) as scope:
+ with ops.name_scope(name, 'expand_and_tile',
+ (tensor, multiple, dim)) as scope:
# Sparse.
tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
if isinstance(tensor, sparse_tensor.SparseTensor):
@@ -2742,8 +2876,8 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx):
Raises:
ValueError: if the last dimension of predictions_idx is not set.
"""
- with ops.name_scope(
- None, 'average_precision', (predictions_idx, labels)) as scope:
+ with ops.name_scope(None, 'average_precision',
+ (predictions_idx, labels)) as scope:
predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx')
if predictions_idx.get_shape().ndims == 0:
raise ValueError('The rank of predictions_idx must be at least 1.')
@@ -2780,10 +2914,12 @@ def _sparse_average_precision_at_top_k(labels, predictions_idx):
retrieved_per_k = math_ops.cumsum(
array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
precision_per_k = math_ops.div(
- math_ops.to_double(tp_per_k), math_ops.to_double(retrieved_per_k),
+ math_ops.to_double(tp_per_k),
+ math_ops.to_double(retrieved_per_k),
name='precision_per_k')
relevant_precision_per_k = math_ops.multiply(
- precision_per_k, math_ops.to_double(relevant_per_k),
+ precision_per_k,
+ math_ops.to_double(relevant_per_k),
name='relevant_precision_per_k')
# Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
@@ -2887,6 +3023,7 @@ def _streaming_sparse_average_precision_at_top_k(labels,
return mean_average_precision, update
+@tf_export('metrics.sparse_average_precision_at_k')
@deprecated(None, 'Use average_precision_at_k instead')
def sparse_average_precision_at_k(labels,
predictions,
@@ -2906,6 +3043,7 @@ def sparse_average_precision_at_k(labels,
name=name)
+@tf_export('metrics.average_precision_at_k')
def average_precision_at_k(labels,
predictions,
k,
@@ -2971,9 +3109,8 @@ def average_precision_at_k(labels,
if k < 1:
raise ValueError('Invalid k=%s.' % k)
- with ops.name_scope(
- name, _at_k_name('average_precision', k),
- (predictions, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('average_precision', k),
+ (predictions, labels, weights)) as scope:
# Calculate top k indices to produce [D1, ... DN, k] tensor.
_, predictions_idx = nn.top_k(predictions, k)
return _streaming_sparse_average_precision_at_top_k(
@@ -3014,17 +3151,16 @@ def _sparse_false_positive_at_k(labels,
Returns:
A [D1, ... DN] `Tensor` of false positive counts.
"""
- with ops.name_scope(
- None, 'false_positives', (predictions_idx, labels, weights)):
- labels, predictions_idx = _maybe_select_class_id(labels,
- predictions_idx,
+ with ops.name_scope(None, 'false_positives',
+ (predictions_idx, labels, weights)):
+ labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
class_id)
- fp = sets.set_size(sets.set_difference(
- predictions_idx, labels, aminusb=True))
+ fp = sets.set_size(
+ sets.set_difference(predictions_idx, labels, aminusb=True))
fp = math_ops.to_double(fp)
if weights is not None:
- with ops.control_dependencies((
- weights_broadcast_ops.assert_broadcastable(weights, fp),)):
+ with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
+ weights, fp),)):
weights = math_ops.to_double(weights)
fp = math_ops.multiply(fp, weights)
return fp
@@ -3068,11 +3204,12 @@ def _streaming_sparse_false_positive_at_k(labels,
Raises:
ValueError: If `weights` is not `None` and has an incompatible shape.
"""
- with ops.name_scope(
- name, _at_k_name('false_positive', k, class_id=class_id),
- (predictions_idx, labels, weights)) as scope:
+ with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
+ (predictions_idx, labels, weights)) as scope:
fp = _sparse_false_positive_at_k(
- predictions_idx=predictions_idx, labels=labels, class_id=class_id,
+ predictions_idx=predictions_idx,
+ labels=labels,
+ class_id=class_id,
weights=weights)
batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
@@ -3080,6 +3217,7 @@ def _streaming_sparse_false_positive_at_k(labels,
return var, state_ops.assign_add(var, batch_total_fp, name='update')
+@tf_export('metrics.precision_at_top_k')
def precision_at_top_k(labels,
predictions_idx,
k=None,
@@ -3143,10 +3281,16 @@ def precision_at_top_k(labels,
labels = _maybe_expand_labels(labels, predictions_idx)
top_k_idx = math_ops.to_int64(predictions_idx)
tp, tp_update = _streaming_sparse_true_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ predictions_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
weights=weights)
fp, fp_update = _streaming_sparse_false_positive_at_k(
- predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
+ predictions_idx=top_k_idx,
+ labels=labels,
+ k=k,
+ class_id=class_id,
weights=weights)
metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
@@ -3159,6 +3303,7 @@ def precision_at_top_k(labels,
return metric, update
+@tf_export('metrics.sparse_precision_at_k')
@deprecated(None, 'Use precision_at_k instead')
def sparse_precision_at_k(labels,
predictions,
@@ -3180,6 +3325,7 @@ def sparse_precision_at_k(labels,
name=name)
+@tf_export('metrics.precision_at_k')
def precision_at_k(labels,
predictions,
k,
@@ -3273,9 +3419,15 @@ def precision_at_k(labels,
name=scope)
-def specificity_at_sensitivity(
- labels, predictions, sensitivity, weights=None, num_thresholds=200,
- metrics_collections=None, updates_collections=None, name=None):
+@tf_export('metrics.specificity_at_sensitivity')
+def specificity_at_sensitivity(labels,
+ predictions,
+ sensitivity,
+ weights=None,
+ num_thresholds=200,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Computes the specificity at a given sensitivity.
The `specificity_at_sensitivity` function creates four local
@@ -3336,8 +3488,9 @@ def specificity_at_sensitivity(
with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
(predictions, labels, weights)):
kepsilon = 1e-7 # to account for floating point imprecisions
- thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- for i in range(num_thresholds-2)]
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]
values, update_ops = _confusion_matrix_at_thresholds(
@@ -3369,8 +3522,7 @@ def specificity_at_sensitivity(
tf_index = math_ops.cast(tf_index, dtypes.int32)
# Now, we have the implicit threshold, so compute the specificity:
- return math_ops.div(tn[tf_index],
- tn[tf_index] + fp[tf_index] + kepsilon,
+ return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
name)
specificity = compute_specificity_at_sensitivity(
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index fc013b565b..eebfb17085 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -21,10 +21,8 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -40,15 +38,6 @@ from tensorflow.python.platform import test
@test_util.with_c_api
class BatchNormalizationTest(test.TestCase):
- def SetProducerVersion(self, graph, producer_version):
- # The C API doesn't expose altering GraphDefVersions. We can indirectly set
- # it via import_graph_def though.
- graph_def = graph_pb2.GraphDef()
- graph_def.versions.producer = producer_version
- with graph.as_default():
- importer.import_graph_def(graph_def)
- assert graph.graph_def_versions.producer, producer_version
-
def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization, shift_after_normalization):
y = (x - m) / np.sqrt(v + epsilon)
@@ -65,7 +54,7 @@ class BatchNormalizationTest(test.TestCase):
def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization):
"""Original implementation."""
- self.SetProducerVersion(ops.get_default_graph(), 8)
+ test_util.set_producer_version(ops.get_default_graph(), 8)
return gen_nn_ops._batch_norm_with_global_normalization(
x, m, v, beta, gamma, epsilon, scale_after_normalization)
# pylint: enable=protected-access
@@ -233,7 +222,7 @@ class BatchNormalizationTest(test.TestCase):
epsilon = 0.001
for scale_after_normalization in [True, False]:
# _batch_norm_with_global_normalization_grad is deprecated in v9
- self.SetProducerVersion(ops.get_default_graph(), 8)
+ test_util.set_producer_version(ops.get_default_graph(), 8)
grad = gen_nn_ops._batch_norm_with_global_normalization_grad(
x, m, v, gamma, backprop, epsilon, scale_after_normalization)
dx, dm, dv, db, dg = grad
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 8cd535aa0b..5e6cafd6aa 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -89,52 +89,63 @@ def _Conv2DBackpropFilterGrad(op, grad):
@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
data_format = op.get_attr("data_format")
- return [nn_ops.conv3d_backprop_input_v2(array_ops.shape(op.inputs[0]),
- op.inputs[1],
- grad,
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format),
- nn_ops.conv3d_backprop_filter_v2(op.inputs[0],
- array_ops.shape(op.inputs[1]),
- grad,
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format)]
+ return [
+ nn_ops.conv3d_backprop_input_v2(
+ array_ops.shape(op.inputs[0]),
+ op.inputs[1],
+ grad,
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format),
+ nn_ops.conv3d_backprop_filter_v2(
+ op.inputs[0],
+ array_ops.shape(op.inputs[1]),
+ grad,
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format)
+ ]
@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
data_format = op.get_attr("data_format")
- return [None,
- nn_ops.conv3d_backprop_filter_v2(grad,
- array_ops.shape(op.inputs[1]),
- op.inputs[2],
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format),
- nn_ops.conv3d(grad,
- op.inputs[1],
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format)]
+ return [
+ None,
+ nn_ops.conv3d_backprop_filter_v2(
+ grad,
+ array_ops.shape(op.inputs[1]),
+ op.inputs[2],
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format),
+ nn_ops.conv3d(
+ grad,
+ op.inputs[1],
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format)
+ ]
@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
data_format = op.get_attr("data_format")
- return [nn_ops.conv3d_backprop_input_v2(array_ops.shape(op.inputs[0]),
- grad,
- op.inputs[2],
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format),
- None,
- nn_ops.conv3d(op.inputs[0],
- grad,
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format)]
+ return [
+ nn_ops.conv3d_backprop_input_v2(
+ array_ops.shape(op.inputs[0]),
+ grad,
+ op.inputs[2],
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format), None,
+ nn_ops.conv3d(
+ op.inputs[0],
+ grad,
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format)
+ ]
@ops.RegisterGradient("AvgPool3D")
@@ -150,12 +161,13 @@ def _AvgPool3DGrad(op, grad):
@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
- return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops.avg_pool3d(
- grad,
- op.get_attr("ksize"),
- op.get_attr("strides"),
- op.get_attr("padding"),
- data_format=op.get_attr("data_format")))
+ return (array_ops.stop_gradient(op.inputs[0]),
+ gen_nn_ops.avg_pool3d(
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
@ops.RegisterGradient("MaxPool3D")
@@ -173,9 +185,9 @@ def _MaxPool3DGrad(op, grad):
@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops._max_pool3d_grad_grad(
op.inputs[0],
op.inputs[1],
@@ -189,9 +201,9 @@ def _MaxPool3DGradGrad(op, grad):
@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops._max_pool3d_grad(
op.inputs[0],
op.inputs[1],
@@ -246,7 +258,7 @@ def _LogSoftmaxGrad(op, grad):
The gradients w.r.t. the input.
"""
softmax = math_ops.exp(op.outputs[0])
- return grad - math_ops.reduce_sum(grad, 1, keep_dims=True) * softmax
+ return grad - math_ops.reduce_sum(grad, 1, keepdims=True) * softmax
@ops.RegisterGradient("BiasAdd")
@@ -272,8 +284,9 @@ def _BiasAddGrad(op, received_grad):
data_format = op.get_attr("data_format")
except ValueError:
data_format = None
- return (received_grad, gen_nn_ops.bias_add_grad(out_backprop=received_grad,
- data_format=data_format))
+ return (received_grad,
+ gen_nn_ops.bias_add_grad(
+ out_backprop=received_grad, data_format=data_format))
@ops.RegisterGradient("BiasAddGrad")
@@ -346,10 +359,9 @@ def _ReluGrad(op, grad):
def _EluGradGrad(op, grad):
elu_x = op.inputs[1]
return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
- array_ops.where(elu_x < 0,
- grad * op.inputs[0],
- array_ops.zeros(shape=array_ops.shape(elu_x),
- dtype=elu_x.dtype)))
+ array_ops.where(elu_x < 0, grad * op.inputs[0],
+ array_ops.zeros(
+ shape=array_ops.shape(elu_x), dtype=elu_x.dtype)))
@ops.RegisterGradient("SeluGrad")
@@ -357,9 +369,11 @@ def _SeluGradGrad(op, grad):
x = op.inputs[1]
scale_alpha = 1.7580993408473768599402175208123
return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
- array_ops.where(
- x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + scale_alpha),
- array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))
+ array_ops.where(x < 0.,
+ gen_nn_ops._elu_grad(grad,
+ op.outputs[0] + scale_alpha),
+ array_ops.zeros(
+ shape=array_ops.shape(x), dtype=x.dtype)))
@ops.RegisterGradient("Relu6")
@@ -370,8 +384,8 @@ def _Relu6Grad(op, grad):
@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
x = op.inputs[1]
- return (gen_nn_ops._relu6_grad(grad, x), array_ops.zeros(
- shape=array_ops.shape(x), dtype=x.dtype))
+ return (gen_nn_ops._relu6_grad(grad, x),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
@ops.RegisterGradient("Elu")
@@ -410,8 +424,8 @@ def _SoftsignGrad(op, grad):
@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
x = op.inputs[1]
- return (gen_nn_ops._relu_grad(grad, x), array_ops.zeros(
- shape=array_ops.shape(x), dtype=x.dtype))
+ return (gen_nn_ops._relu_grad(grad, x),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
def _BroadcastMul(vec, mat):
@@ -455,8 +469,8 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
softmax = nn_ops.softmax(logits)
grad += ((grad_grad - array_ops.squeeze(
- math_ops.matmul(grad_grad[:, None, :],
- softmax[:, :, None]), axis=1)) * softmax)
+ math_ops.matmul(grad_grad[:, None, :], softmax[:, :, None]), axis=1)) *
+ softmax)
return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
@@ -473,7 +487,8 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
- op.outputs[1], message="Currently there is no way to take the second "
+ op.outputs[1],
+ message="Currently there is no way to take the second "
"derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
"implementation's interaction with tf.gradients()")
return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None
@@ -531,14 +546,16 @@ def _DepthwiseConv2dNativeGrad(op, grad):
@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
- return [nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
- op.get_attr("strides"),
- op.get_attr("rates"),
- op.get_attr("padding")),
- nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
- op.get_attr("strides"),
- op.get_attr("rates"),
- op.get_attr("padding"))]
+ return [
+ nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
+ op.get_attr("strides"),
+ op.get_attr("rates"),
+ op.get_attr("padding")),
+ nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
+ op.get_attr("strides"),
+ op.get_attr("rates"),
+ op.get_attr("padding"))
+ ]
@ops.RegisterGradient("LRN")
@@ -547,8 +564,10 @@ def _LRNGrad(op, grad):
bias = op.get_attr("bias")
alpha = op.get_attr("alpha")
beta = op.get_attr("beta")
- return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius,
- bias, alpha, beta)]
+ return [
+ gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius,
+ bias, alpha, beta)
+ ]
@ops.RegisterGradient("AvgPool")
@@ -564,54 +583,58 @@ def _AvgPoolGrad(op, grad):
@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
- return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops._avg_pool(
- grad,
- op.get_attr("ksize"),
- op.get_attr("strides"),
- op.get_attr("padding"),
- data_format=op.get_attr("data_format")))
+ return (array_ops.stop_gradient(op.inputs[0]),
+ gen_nn_ops._avg_pool(
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
- return gen_nn_ops._max_pool_grad(op.inputs[0],
- op.outputs[0],
- grad,
- op.get_attr("ksize"),
- op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format"))
+ return gen_nn_ops._max_pool_grad(
+ op.inputs[0],
+ op.outputs[0],
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format"))
@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
ksize = op.inputs[1]
strides = op.inputs[2]
- return gen_nn_ops.max_pool_grad_v2(op.inputs[0],
- op.outputs[0],
- grad,
- ksize,
- strides,
- padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format")), None, None
+ return gen_nn_ops.max_pool_grad_v2(
+ op.inputs[0],
+ op.outputs[0],
+ grad,
+ ksize,
+ strides,
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format")), None, None
@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
- return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
- grad,
- op.outputs[1],
- op.get_attr("ksize"),
- op.get_attr("strides"),
- padding=op.get_attr("padding"))
+ return gen_nn_ops._max_pool_grad_with_argmax(
+ op.inputs[0],
+ grad,
+ op.outputs[1],
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"))
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops._max_pool_grad_grad(
op.inputs[0],
op.inputs[1],
@@ -627,9 +650,9 @@ def _MaxPoolGradGradV2(op, grad):
ksize = op.inputs[3]
strides = op.inputs[4]
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops.max_pool_grad_grad_v2(
op.inputs[0],
op.inputs[1],
@@ -643,9 +666,9 @@ def _MaxPoolGradGradV2(op, grad):
@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops._max_pool_grad(
op.inputs[0],
op.inputs[1],
@@ -674,10 +697,9 @@ def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
Input backprop for FractionalMaxPool op.
"""
# pylint: disable=protected-access
- return gen_nn_ops._fractional_max_pool_grad(op.inputs[0], op.outputs[0],
- grad_0, op.outputs[1],
- op.outputs[2],
- op.get_attr("overlapping"))
+ return gen_nn_ops._fractional_max_pool_grad(
+ op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2],
+ op.get_attr("overlapping"))
@ops.RegisterGradient("FractionalAvgPool")
@@ -761,8 +783,9 @@ def _BaseFusedBatchNormGrad(op, use_v2, *grad):
epsilon = op.get_attr("epsilon")
data_format = op.get_attr("data_format")
is_training = op.get_attr("is_training")
- grad_fun = (gen_nn_ops.fused_batch_norm_grad_v2 if use_v2
- else gen_nn_ops.fused_batch_norm_grad)
+ grad_fun = (
+ gen_nn_ops.fused_batch_norm_grad_v2
+ if use_v2 else gen_nn_ops.fused_batch_norm_grad)
if is_training:
return grad_fun(
grad_y,
@@ -786,7 +809,7 @@ def _BaseFusedBatchNormGrad(op, use_v2, *grad):
pop_mean,
pop_var,
epsilon=epsilon,
- data_format='NHWC',
+ data_format="NHWC",
is_training=is_training)
if data_format == b"NCHW":
dx = array_ops.transpose(dx, [0, 3, 1, 2])
@@ -803,18 +826,28 @@ def _FusedBatchNormV2Grad(op, *grad):
return _BaseFusedBatchNormGrad(op, True, *grad)
-def _BatchNormGrad(grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training=True):
+def _BatchNormGrad(grad_y,
+ x,
+ scale,
+ pop_mean,
+ pop_var,
+ epsilon,
+ data_format,
+ is_training=True):
"""Returns the gradients for the 3 inputs of BatchNorm.
Args:
grad_y: A `Tensor` of 4 dimensions for gradient for y.
x: A `Tensor` of 4 dimensions for x.
scale: A `Tensor` of 1 dimension for scaling.
- pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when is_training=False.
- pop_var: A `Tensor` of 1 dimension for the population variance. Only used when is_training=False.
+ pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
+ is_training=False.
+ pop_var: A `Tensor` of 1 dimension for the population variance. Only used
+ when is_training=False.
epsilon: A small float number added to the variance of x.
data_format: The data format for input. Either b"NHWC" or b"NCHW".
- is_training: A bool value to indicate the operation is for training (default)
+ is_training: A bool value to indicate the operation is for training
+ (default)
or inference.
Returns:
@@ -900,7 +933,7 @@ def _FusedBatchNormGradGrad(op, *grad):
grad_grad_scale = grad[1]
grad_grad_offset = grad[2]
grad_x, grad_scale, grad_offset = _BatchNormGrad(
- grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
+ grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
grad_grad_y, grad_x, grad_scale = gradients_impl.gradients(
[grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
@@ -954,14 +987,15 @@ def _TopKGrad(op, grad, _):
# Substitute grad to appropriate locations and fill the rest with zeros,
# finally reshaping it to the original input shape.
- return [array_ops.reshape(
- sparse_ops.sparse_to_dense(ind,
- array_ops.reshape(
- math_ops.reduce_prod(in_shape), [1]),
- array_ops.reshape(grad, [-1]),
- validate_indices=False),
- in_shape), array_ops.zeros(
- [], dtype=dtypes.int32)]
+ return [
+ array_ops.reshape(
+ sparse_ops.sparse_to_dense(
+ ind,
+ array_ops.reshape(math_ops.reduce_prod(in_shape), [1]),
+ array_ops.reshape(grad, [-1]),
+ validate_indices=False), in_shape),
+ array_ops.zeros([], dtype=dtypes.int32)
+ ]
@ops.RegisterGradient("NthElement")
@@ -983,11 +1017,9 @@ def _NthElementGrad(op, grad):
# dimension. If there are multiple elements then the gradient will be
# divided between them.
indicators = math_ops.cast(
- math_ops.equal(array_ops.expand_dims(output, -1), input),
- grad.dtype)
+ math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype)
grad = array_ops.expand_dims(grad, -1)
- num_selected = array_ops.expand_dims(
- math_ops.reduce_sum(indicators, -1), -1)
+ num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)
return [math_ops.div(indicators, num_selected) * grad, None]
diff --git a/tensorflow/python/ops/nn_grad_test.py b/tensorflow/python/ops/nn_grad_test.py
index f7541c0e89..aa7539ae9f 100644
--- a/tensorflow/python/ops/nn_grad_test.py
+++ b/tensorflow/python/ops/nn_grad_test.py
@@ -30,17 +30,20 @@ from tensorflow.python.platform import test
class Relu6OpTest(test.TestCase):
+
def testRelu6GradGrad(self):
- inputs = constant_op.constant([[-2, -1, 1, 3], [5, 7, 8, 9]],
- dtype=dtypes.float32)
+ inputs = constant_op.constant(
+ [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
x_init_value = np.array([[-3.5, -1.5, 2, 4], [4.5, 7.5, 8.5, 11]])
r = nn_ops.relu6(inputs)
r_g = gradients_impl.gradients(r, inputs)[0]
with self.test_session():
error = gradient_checker.compute_gradient_error(
- inputs, inputs.get_shape().as_list(),
- r_g, r_g.get_shape().as_list(),
- x_init_value=x_init_value)
+ inputs,
+ inputs.get_shape().as_list(),
+ r_g,
+ r_g.get_shape().as_list(),
+ x_init_value=x_init_value)
self.assertLess(error, 1e-4)
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index fd96f7b8fc..55fcd176d6 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -35,8 +35,10 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("nn.log_poisson_loss")
def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
"""Computes log Poisson loss given `log_input`.
@@ -101,6 +103,7 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
return result
+@tf_export("nn.sigmoid_cross_entropy_with_logits")
def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name
_sentinel=None,
labels=None,
@@ -180,6 +183,7 @@ def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name
name=name)
+@tf_export("nn.weighted_cross_entropy_with_logits")
def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
"""Computes a weighted cross entropy.
@@ -192,7 +196,13 @@ def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
targets * -log(sigmoid(logits)) +
(1 - targets) * -log(1 - sigmoid(logits))
- The argument `pos_weight` is used as a multiplier for the positive targets:
+ A value `pos_weights > 1` decreases the false negative count, hence increasing
+ the recall.
+ Conversely setting `pos_weights < 1` decreases the false positive count and
+ increases the precision.
+ This can be seen from the fact that `pos_weight` is introduced as a
+ multiplicative coefficient for the positive targets term
+ in the loss expression:
targets * -log(sigmoid(logits)) * pos_weight +
(1 - targets) * -log(1 - sigmoid(logits))
@@ -251,6 +261,7 @@ def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
name=name)
+@tf_export("nn.relu_layer")
def relu_layer(x, weights, biases, name=None):
"""Computes Relu(x * weight + biases).
@@ -297,6 +308,7 @@ def _swish_grad(features, grad):
shape_func=_swish_shape,
func_name="swish",
noinline=True)
+@tf_export("nn.swish")
def swish(features):
# pylint: disable=g-doc-args
"""Computes the Swish activation function: `x * sigmoid(x)`.
@@ -316,6 +328,7 @@ def swish(features):
return features * math_ops.sigmoid(features)
+@tf_export("nn.l2_normalize")
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
"""Normalizes along dimension `axis` using an L2 norm.
@@ -347,6 +360,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
return math_ops.multiply(x, x_inv_norm, name=name)
+@tf_export("nn.zero_fraction")
def zero_fraction(value, name=None):
"""Returns the fraction of zeros in `value`.
@@ -374,6 +388,7 @@ def zero_fraction(value, name=None):
# pylint: disable=redefined-builtin
+@tf_export("nn.depthwise_conv2d")
def depthwise_conv2d(input,
filter,
strides,
@@ -450,6 +465,7 @@ def depthwise_conv2d(input,
# pylint: disable=redefined-builtin,line-too-long
+@tf_export("nn.separable_conv2d")
def separable_conv2d(input,
depthwise_filter,
pointwise_filter,
@@ -550,6 +566,7 @@ def separable_conv2d(input,
# pylint: enable=redefined-builtin,line-too-long
+@tf_export("nn.sufficient_statistics")
def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
"""Calculate the sufficient statistics for the mean and variance of `x`.
@@ -599,6 +616,7 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None):
return counts, m_ss, v_ss, shift
+@tf_export("nn.normalize_moments")
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
"""Calculate the mean and variance of based on the sufficient statistics.
@@ -630,9 +648,13 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
return (mean, variance)
-def moments(x, axes,
- shift=None, # pylint: disable=unused-argument
- name=None, keep_dims=False):
+@tf_export("nn.moments")
+def moments(
+ x,
+ axes,
+ shift=None, # pylint: disable=unused-argument
+ name=None,
+ keep_dims=False):
"""Calculate the mean and variance of `x`.
The mean and variance are calculated by aggregating the contents of `x`
@@ -676,12 +698,13 @@ def moments(x, axes,
mean = array_ops.squeeze(mean, axes)
variance = array_ops.squeeze(variance, axes)
if x.dtype == dtypes.float16:
- return (math_ops.cast(mean, dtypes.float16), math_ops.cast(
- variance, dtypes.float16))
+ return (math_ops.cast(mean, dtypes.float16),
+ math_ops.cast(variance, dtypes.float16))
else:
return (mean, variance)
+@tf_export("nn.weighted_moments")
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
"""Returns the frequency-weighted mean and variance of `x`.
@@ -753,6 +776,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
return weighted_mean, weighted_variance
+@tf_export("nn.batch_normalization")
def batch_normalization(x,
mean,
variance,
@@ -806,10 +830,11 @@ def batch_normalization(x,
inv = math_ops.rsqrt(variance + variance_epsilon)
if scale is not None:
inv *= scale
- return x * inv + (offset - mean * inv
- if offset is not None else -mean * inv)
+ return x * inv + (
+ offset - mean * inv if offset is not None else -mean * inv)
+@tf_export("nn.fused_batch_norm")
def fused_batch_norm(
x,
scale,
@@ -882,6 +907,7 @@ def fused_batch_norm(
return y, batch_mean, batch_var
+@tf_export("nn.batch_norm_with_global_normalization")
def batch_norm_with_global_normalization(t,
m,
v,
@@ -943,7 +969,8 @@ def _compute_sampled_logits(weights,
subtract_log_q=True,
remove_accidental_hits=False,
partition_strategy="mod",
- name=None):
+ name=None,
+ seed=None):
"""Helper function for nce_loss and sampled_softmax_loss functions.
Computes sampled output training logits and labels suitable for implementing
@@ -981,6 +1008,8 @@ def _compute_sampled_logits(weights,
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: A name for the operation (optional).
+ seed: random seed for candidate sampling. Default to None, which doesn't set
+ the op-level random seed for candidate sampling.
Returns:
out_logits: `Tensor` object with shape
`[batch_size, num_true + num_sampled]`, for passing to either
@@ -1010,7 +1039,8 @@ def _compute_sampled_logits(weights,
num_true=num_true,
num_sampled=num_sampled,
unique=True,
- range_max=num_classes)
+ range_max=num_classes,
+ seed=seed)
# NOTE: pylint cannot tell that 'sampled_values' is a sequence
# pylint: disable=unpacking-non-sequence
sampled, true_expected_count, sampled_expected_count = (
@@ -1109,6 +1139,7 @@ def _compute_sampled_logits(weights,
return out_logits, out_labels
+@tf_export("nn.nce_loss")
def nce_loss(weights,
biases,
labels,
@@ -1217,6 +1248,7 @@ def nce_loss(weights,
return _sum_rows(sampled_losses)
+@tf_export("nn.sampled_softmax_loss")
def sampled_softmax_loss(weights,
biases,
labels,
@@ -1227,7 +1259,8 @@ def sampled_softmax_loss(weights,
sampled_values=None,
remove_accidental_hits=True,
partition_strategy="mod",
- name="sampled_softmax_loss"):
+ name="sampled_softmax_loss",
+ seed=None):
"""Computes and returns the sampled softmax training loss.
This is a faster way to train a softmax classifier over a huge number of
@@ -1288,6 +1321,8 @@ def sampled_softmax_loss(weights,
if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: A name for the operation (optional).
+ seed: random seed for candidate sampling. Default to None, which doesn't set
+ the op-level random seed for candidate sampling.
Returns:
A `batch_size` 1-D tensor of per-example sampled softmax losses.
@@ -1305,7 +1340,8 @@ def sampled_softmax_loss(weights,
subtract_log_q=True,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
- name=name)
+ name=name,
+ seed=seed)
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
# sampled_losses is a [batch_size] tensor.
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 865e459e90..32b14f86b5 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops.gen_nn_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
# Aliases for some automatically-generated names.
@@ -190,6 +191,7 @@ class _NonAtrousConvolution(object):
name=self.name)
+@tf_export("nn.with_space_to_batch")
def with_space_to_batch(
input, # pylint: disable=redefined-builtin
dilation_rate,
@@ -633,6 +635,7 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate):
return strides, dilation_rate
+@tf_export("nn.convolution")
def convolution(input, filter, # pylint: disable=redefined-builtin
padding, strides=None, dilation_rate=None,
name=None, data_format=None):
@@ -848,6 +851,7 @@ class Convolution(object):
return self.conv_op(inp, filter)
+@tf_export("nn.pool")
def pool(input, # pylint: disable=redefined-builtin
window_shape,
pooling_type,
@@ -1015,6 +1019,7 @@ def pool(input, # pylint: disable=redefined-builtin
filter_shape=window_shape)
+@tf_export("nn.atrous_conv2d")
def atrous_conv2d(value, filters, rate, padding, name=None):
"""Atrous convolution (a.k.a. convolution with holes or dilated convolution).
@@ -1150,6 +1155,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
name=name)
+@tf_export("nn.conv2d_transpose")
def conv2d_transpose(value,
filter, # pylint: disable=redefined-builtin
output_shape,
@@ -1225,6 +1231,7 @@ def conv2d_transpose(value,
name=name)
+@tf_export("nn.atrous_conv2d_transpose")
def atrous_conv2d_transpose(value,
filters,
output_shape,
@@ -1371,6 +1378,7 @@ def atrous_conv2d_transpose(value,
block_size=rate)
+@tf_export("nn.conv3d_transpose")
def conv3d_transpose(value,
filter, # pylint: disable=redefined-builtin
output_shape,
@@ -1444,6 +1452,7 @@ def conv3d_transpose(value,
# pylint: disable=protected-access
+@tf_export("nn.bias_add")
def bias_add(value, bias, data_format=None, name=None):
"""Adds `bias` to `value`.
@@ -1498,6 +1507,7 @@ def bias_add_v1(value, bias, name=None):
return gen_nn_ops._bias_add_v1(value, bias, name=name)
+@tf_export("nn.crelu")
def crelu(features, name=None, axis=-1):
"""Computes Concatenated ReLU.
@@ -1521,6 +1531,7 @@ def crelu(features, name=None, axis=-1):
return gen_nn_ops.relu(c)
+@tf_export("nn.relu6")
def relu6(features, name=None):
"""Computes Rectified Linear 6: `min(max(features, 0), 6)`.
Source: [Convolutional Deep Belief Networks on CIFAR-10. A. Krizhevsky](http://www.cs.utoronto.ca/~kriz/conv-cifar10-aug2010.pdf)
@@ -1538,6 +1549,7 @@ def relu6(features, name=None):
return gen_nn_ops._relu6(features, name=name)
+@tf_export("nn.leaky_relu")
def leaky_relu(features, alpha=0.2, name=None):
"""Compute the Leaky ReLU activation function.
@@ -1546,7 +1558,8 @@ def leaky_relu(features, alpha=0.2, name=None):
http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf
Args:
- features: A `Tensor` representing preactivation values.
+ features: A `Tensor` representing preactivation values. Must be one of
+ the following types: `float16`, `float32`, `float64`, `int32`, `int64`.
alpha: Slope of the activation function at x < 0.
name: A name for the operation (optional).
@@ -1555,7 +1568,9 @@ def leaky_relu(features, alpha=0.2, name=None):
"""
with ops.name_scope(name, "LeakyRelu", [features, alpha]):
features = ops.convert_to_tensor(features, name="features")
- alpha = ops.convert_to_tensor(alpha, name="alpha")
+ if features.dtype.is_integer:
+ features = math_ops.to_float(features)
+ alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
return math_ops.maximum(alpha * features, features)
@@ -1661,6 +1676,7 @@ def _softmax(logits, compute_op, dim=-1, name=None):
return output
+@tf_export("nn.softmax")
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax(logits, axis=None, name=None, dim=None):
"""Computes softmax activations.
@@ -1690,6 +1706,7 @@ def softmax(logits, axis=None, name=None, dim=None):
return _softmax(logits, gen_nn_ops._softmax, axis, name)
+@tf_export("nn.log_softmax")
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def log_softmax(logits, axis=None, name=None, dim=None):
"""Computes log softmax activations.
@@ -1728,6 +1745,7 @@ def _ensure_xent_args(name, sentinel, labels, logits):
raise ValueError("Both labels and logits must be provided.")
+@tf_export("nn.softmax_cross_entropy_with_logits_v2")
def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=invalid-name
labels=None, logits=None,
dim=-1, name=None):
@@ -1842,6 +1860,7 @@ See tf.nn.softmax_cross_entropy_with_logits_v2.
"""
+@tf_export("nn.softmax_cross_entropy_with_logits")
@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
labels=None, logits=None,
@@ -1898,6 +1917,7 @@ def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid
labels=labels, logits=logits, dim=dim, name=name)
+@tf_export("nn.sparse_softmax_cross_entropy_with_logits")
def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
labels=None, logits=None,
name=None):
@@ -1996,6 +2016,7 @@ def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=
return cost
+@tf_export("nn.avg_pool")
def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
"""Performs the average pooling on the input.
@@ -2028,6 +2049,7 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
name=name)
+@tf_export("nn.max_pool")
def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
"""Performs the max pooling on the input.
@@ -2099,6 +2121,7 @@ def _calc_bias_add_flops(graph, node):
return ops.OpStats("flops", input_count)
+@tf_export("nn.xw_plus_b")
def xw_plus_b(x, weights, biases, name=None): # pylint: disable=invalid-name
"""Computes matmul(x, weights) + biases.
@@ -2145,6 +2168,7 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name
return bias_add_v1(mm, biases, name=name)
+@tf_export("nn.dropout")
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
"""Computes dropout.
@@ -2209,6 +2233,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
return ret
+@tf_export("nn.top_k")
def top_k(input, k=1, sorted=True, name=None):
"""Finds values and indices of the `k` largest entries for the last dimension.
@@ -2266,6 +2291,7 @@ def nth_element(input, n, reverse=False, name=None):
return gen_nn_ops.nth_element(input, n, reverse=reverse, name=name)
+@tf_export("nn.conv1d")
@deprecation.deprecated_arg_values(
None, "`NCHW` for data_format is deprecated, use `NCW` instead",
warn_once=True, data_format="NCHW")
@@ -2300,7 +2326,7 @@ def conv1d(value, filters, stride, padding,
returned to the caller.
Args:
- value: A 3D `Tensor`. Must be of type `float32` or `float64`.
+ value: A 3D `Tensor`. Must be of type `float16` or `float32`.
filters: A 3D `Tensor`. Must have the same type as `input`.
stride: An `integer`. The number of entries by which
the filter is moved right at each step.
@@ -2451,6 +2477,7 @@ def _calc_dilation2d_flops(graph, node):
return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
+@tf_export("nn.erosion2d")
def erosion2d(value, kernel, strides, rates, padding, name=None):
"""Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors.
@@ -2508,6 +2535,7 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
name=name))
+@tf_export("nn.in_top_k")
def in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 66bc0803b7..5a45bdc1e5 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -131,8 +131,7 @@ class LogPoissonLossTest(test_lib.TestCase):
y_np = self._log_poisson_loss(x_np, z_np, compute_full_loss=False)
y_np_stirling = self._log_poisson_loss(x_np, z_np, compute_full_loss=True)
y_tf = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=False)
- y_tf_stirling = nn_impl.log_poisson_loss(
- z_np, x_np, compute_full_loss=True)
+ y_tf_stirling = nn_impl.log_poisson_loss(z_np, x_np, compute_full_loss=True)
y_tf_np = self.evaluate(y_tf)
y_tf_np_stirling = self.evaluate(y_tf_stirling)
eps = 1e-3
@@ -773,8 +772,8 @@ class ComputeSampledLogitsTest(test_lib.TestCase):
def _SoftmaxCrossEntropyWithLogits(logits, targets):
# logits, targets: float arrays of the same shape.
assert logits.shape == targets.shape
- stable_exp_logits = np.exp(logits - np.amax(
- logits, axis=1, keepdims=True))
+ stable_exp_logits = np.exp(
+ logits - np.amax(logits, axis=1, keepdims=True))
pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
@@ -865,8 +864,8 @@ class LeakyReluTest(test_lib.TestCase):
batch_size = 3
height, width = 4, 4
np.random.seed(1) # Make it reproducible.
- inputs = np.random.uniform(
- size=(batch_size, height, width, 3)).astype(np.float32)
+ inputs = np.random.uniform(size=(batch_size, height, width, 3)).astype(
+ np.float32)
inputs = constant_op.constant(inputs)
outputs = nn_ops.leaky_relu(inputs)
@@ -878,11 +877,14 @@ class LeakyReluTest(test_lib.TestCase):
self.assertAllClose(inputs, outputs)
def testValues(self):
- np_values = np.array([-1.0, 0.0, 0.5, 1.0, 2.0], dtype=np.float32)
- outputs = nn_ops.leaky_relu(constant_op.constant(np_values))
- with self.test_session() as sess:
- outputs = sess.run(outputs)
- self.assertAllClose(outputs, [-0.2, 0.0, 0.5, 1.0, 2.0])
+ for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+ np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype)
+ outputs = nn_ops.leaky_relu(constant_op.constant(np_values))
+ with self.test_session() as sess:
+ outputs = sess.run(outputs)
+ tol = 2e-3 if dtype == np.float16 else 1e-6
+ self.assertAllClose(
+ outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol)
class SwishTest(test_lib.TestCase):
@@ -913,7 +915,10 @@ class SwishTest(test_lib.TestCase):
class MomentsTest(test_lib.TestCase):
- def doOutputTest(self, input_shape, moments_axes, tol=1e-4,
+ def doOutputTest(self,
+ input_shape,
+ moments_axes,
+ tol=1e-4,
check_gradients=False):
for mu in [0.0, 1.0, 1e3]:
for sigma in [1.0, 0.1]:
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
index f3558fda9c..b4ce1cbf25 100644
--- a/tensorflow/python/ops/numerics.py
+++ b/tensorflow/python/ops/numerics.py
@@ -24,8 +24,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("verify_tensor_all_finite")
def verify_tensor_all_finite(t, msg, name=None):
"""Assert that the tensor does not contain any NaN's or Inf's.
@@ -45,6 +47,7 @@ def verify_tensor_all_finite(t, msg, name=None):
return out
+@tf_export("add_check_numerics_ops")
def add_check_numerics_ops():
"""Connect a `check_numerics` to every floating point tensor.
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 7b6f08f68c..b0315ceee2 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.gen_parsing_ops import *
# pylint: enable=wildcard-import,undefined-variable
from tensorflow.python.platform import tf_logging
+from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable("DecodeRaw")
@@ -44,6 +45,7 @@ ops.NotDifferentiable("SerializeTensor")
ops.NotDifferentiable("StringToNumber")
+@tf_export("VarLenFeature")
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
"""Configuration for parsing a variable-length input feature.
@@ -53,6 +55,7 @@ class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
pass
+@tf_export("SparseFeature")
class SparseFeature(
collections.namedtuple(
"SparseFeature",
@@ -127,6 +130,7 @@ class SparseFeature(
cls, index_key, value_key, dtype, size, already_sorted)
+@tf_export("FixedLenFeature")
class FixedLenFeature(collections.namedtuple(
"FixedLenFeature", ["shape", "dtype", "default_value"])):
"""Configuration for parsing a fixed-length input feature.
@@ -146,6 +150,7 @@ class FixedLenFeature(collections.namedtuple(
cls, shape, dtype, default_value)
+@tf_export("FixedLenSequenceFeature")
class FixedLenSequenceFeature(collections.namedtuple(
"FixedLenSequenceFeature",
["shape", "dtype", "allow_missing", "default_value"])):
@@ -355,6 +360,7 @@ def _prepend_none_dimension(features):
return features
+@tf_export("parse_example")
def parse_example(serialized, features, name=None, example_names=None):
# pylint: disable=line-too-long
"""Parses `Example` protos into a `dict` of tensors.
@@ -715,6 +721,7 @@ def _parse_example_raw(serialized,
return dict(zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
+@tf_export("parse_single_example")
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
@@ -850,6 +857,7 @@ def _parse_single_example_raw(serialized,
return outputs
+@tf_export("parse_single_sequence_example")
def parse_single_sequence_example(
serialized, context_features=None, sequence_features=None,
example_name=None, name=None):
@@ -1171,6 +1179,7 @@ def _parse_single_sequence_example_raw(serialized,
# Swap `name` and `na_value` for backward compatibility.
+@tf_export("decode_csv")
def decode_csv(records, record_defaults, field_delim=",",
use_quote_delim=True, name=None, na_value=""):
# pylint: disable=protected-access
diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py
index edcc0e1d7c..174cabdf80 100644
--- a/tensorflow/python/ops/partitioned_variables.py
+++ b/tensorflow/python/ops/partitioned_variables.py
@@ -58,6 +58,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
__all__ = [
"create_partitioned_variables",
@@ -67,6 +68,7 @@ __all__ = [
]
+@tf_export("variable_axis_size_partitioner")
def variable_axis_size_partitioner(
max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
"""Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.
@@ -151,6 +153,7 @@ def variable_axis_size_partitioner(
return _partitioner
+@tf_export("min_max_variable_partitioner")
def min_max_variable_partitioner(max_partitions=1, axis=0,
min_slice_size=256 << 10,
bytes_per_string_element=16):
@@ -214,6 +217,7 @@ def min_max_variable_partitioner(max_partitions=1, axis=0,
return _partitioner
+@tf_export("fixed_size_partitioner")
def fixed_size_partitioner(num_shards, axis=0):
"""Partitioner to specify a fixed number of shards along given axis.
@@ -232,6 +236,7 @@ def fixed_size_partitioner(num_shards, axis=0):
return _partitioner
+@tf_export("create_partitioned_variables")
def create_partitioned_variables(
shape, slicing, initializer, dtype=dtypes.float32,
trainable=True, collections=None, name=None, reuse=None):
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index a2264a7bdf..2c86358d21 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_random_ops import *
+from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -43,6 +44,7 @@ def _ShapeTensor(shape):
# pylint: disable=protected-access
+@tf_export("random_normal")
def random_normal(shape,
mean=0.0,
stddev=1.0,
@@ -135,6 +137,7 @@ def parameterized_truncated_normal(shape,
return rnd
+@tf_export("truncated_normal")
def truncated_normal(shape,
mean=0.0,
stddev=1.0,
@@ -179,6 +182,7 @@ ops.NotDifferentiable("ParameterizedTruncatedNormal")
ops.NotDifferentiable("TruncatedNormal")
+@tf_export("random_uniform")
def random_uniform(shape,
minval=0,
maxval=None,
@@ -244,6 +248,7 @@ def random_uniform(shape,
ops.NotDifferentiable("RandomUniform")
+@tf_export("random_shuffle")
def random_shuffle(value, seed=None, name=None):
"""Randomly shuffles a tensor along its first dimension.
@@ -274,6 +279,7 @@ def random_shuffle(value, seed=None, name=None):
value, seed=seed1, seed2=seed2, name=name)
+@tf_export("random_crop")
def random_crop(value, size, seed=None, name=None):
"""Randomly crops a tensor to a given size.
@@ -316,6 +322,7 @@ def random_crop(value, size, seed=None, name=None):
return array_ops.slice(value, offset, size, name=name)
+@tf_export("multinomial")
def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
"""Draws samples from a multinomial distribution.
@@ -351,6 +358,7 @@ def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
ops.NotDifferentiable("Multinomial")
+@tf_export("random_gamma")
def random_gamma(shape,
alpha,
beta=None,
@@ -418,6 +426,7 @@ def random_gamma(shape,
ops.NotDifferentiable("RandomGamma")
+@tf_export("random_poisson")
def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
"""Draws `shape` samples from each of the given Poisson distribution(s).
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 879c206313..bdf41cd75d 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -348,11 +348,11 @@ class ResourceVariable(variables.Variable):
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
self._save_slice_info = None
- self._in_graph_mode = context.in_graph_mode()
# Save the graph's container prefix for error checking. Reading the value of
# the ResourceVariable from another Graph in Eager mode is an error.
self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access
- with ops.control_dependencies(None):
+ with ops.init_scope():
+ self._in_graph_mode = context.in_graph_mode()
with ops.name_scope(name, "Variable", []
if init_from_fn else [initial_value]) as name:
# pylint: disable=protected-access
@@ -835,25 +835,45 @@ class ResourceVariable(variables.Variable):
return self.value()
def __iadd__(self, unused_other):
- raise RuntimeError("Variable += value not supported.")
+ raise RuntimeError("Variable += value not supported. Use "
+ "variable.assign_add(value) to modify the variable "
+ "value and variable = variable + value to get a new "
+ "Tensor object.")
def __isub__(self, unused_other):
- raise RuntimeError("Variable -= value not supported.")
+ raise RuntimeError("Variable -= value not supported. Use "
+ "variable.assign_sub(value) to modify the variable "
+ "value and variable = variable - value to get a new "
+ "Tensor object.")
def __imul__(self, unused_other):
- raise RuntimeError("Variable *= value not supported.")
+ raise RuntimeError("Variable *= value not supported. Use "
+ "variable.assign_mul(value) to modify the variable "
+ "value and variable = variable * value to get a new "
+ "Tensor object.")
def __idiv__(self, unused_other):
- raise RuntimeError("Variable /= value not supported.")
+ raise RuntimeError("Variable /= value not supported. Use "
+ "variable.assign_div(value) to modify the variable "
+ "value and variable = variable / value to get a new "
+ "Tensor object.")
def __itruediv__(self, unused_other):
- raise RuntimeError("Variable /= value not supported.")
+ raise RuntimeError("Variable /= value not supported. Use "
+ "variable.assign_div(value) to modify the variable "
+ "value and variable = variable / value to get a new "
+ "Tensor object.")
def __irealdiv__(self, unused_other):
- raise RuntimeError("Variable /= value not supported.")
+ raise RuntimeError("Variable /= value not supported. Use "
+ "variable.assign_div(value) to modify the variable "
+ "value and variable = variable / value to get a new "
+ "Tensor object.")
def __ipow__(self, unused_other):
- raise RuntimeError("Variable **= value not supported.")
+ raise RuntimeError("Variable **= value not supported. Use "
+ "value and variable = variable ** value to get a new "
+ "Tensor object.")
def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index fd14740a00..a10e1963d1 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -35,12 +35,12 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@@ -321,6 +321,7 @@ def _reverse_seq(input_seq, lengths):
return results
+@tf_export("nn.bidirectional_dynamic_rnn")
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
initial_state_fw=None, initial_state_bw=None,
dtype=None, parallel_iterations=None,
@@ -450,6 +451,7 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
return (outputs, output_states)
+@tf_export("nn.dynamic_rnn")
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
dtype=None, parallel_iterations=None, swap_memory=False,
time_major=False, scope=None):
@@ -723,6 +725,8 @@ def _dynamic_rnn_loop(cell,
if sequence_length is not None:
min_sequence_length = math_ops.reduce_min(sequence_length)
max_sequence_length = math_ops.reduce_max(sequence_length)
+ else:
+ max_sequence_length = time_steps
time = array_ops.constant(0, dtype=dtypes.int32, name="time")
@@ -807,28 +811,21 @@ def _dynamic_rnn_loop(cell,
return (time + 1, output_ta_t, new_state)
- # TODO(pbar) `loop_bound` can be reduced to `max_sequence_length` once
- # TensorArray shape inference is working. When sequence lengths are highly
- # variable, this will reduce the performance overheads of padding to a fixed
- # maximum length.
- loop_bound = time_steps
-
- # This is a workaround since we cannot currently use maximum_iterations if
- # time_steps is defined inside control flow, see the comment in
- # control_flow_ops.py.
- if (context.in_eager_mode() or
- not (control_flow_util.IsInWhileLoop(time_steps.op) or
- control_flow_util.IsInCond(time_steps.op))):
- maximum_iterations = time_steps
+ if in_graph_mode:
+ # Make sure that we run at least 1 step, if necessary, to ensure
+ # the TensorArrays pick up the dynamic shape.
+ loop_bound = math_ops.minimum(
+ time_steps, math_ops.maximum(1, max_sequence_length))
else:
- maximum_iterations = None
+ # Using max_sequence_length isn't currently supported in the Eager branch.
+ loop_bound = time_steps
_, output_final_ta, final_state = control_flow_ops.while_loop(
cond=lambda time, *_: time < loop_bound,
body=_time_step,
loop_vars=(time, output_ta, state),
parallel_iterations=parallel_iterations,
- maximum_iterations=maximum_iterations,
+ maximum_iterations=time_steps,
swap_memory=swap_memory)
# Unpack final output if not using output tuples.
@@ -850,6 +847,7 @@ def _dynamic_rnn_loop(cell,
return (final_outputs, final_state)
+@tf_export("nn.raw_rnn")
def raw_rnn(cell, loop_fn,
parallel_iterations=None, swap_memory=False, scope=None):
"""Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.
@@ -1157,6 +1155,7 @@ def raw_rnn(cell, loop_fn,
return (emit_ta, final_state, final_loop_state)
+@tf_export("nn.static_rnn")
def static_rnn(cell,
inputs,
initial_state=None,
@@ -1326,6 +1325,7 @@ def static_rnn(cell,
return (outputs, state)
+@tf_export("nn.static_state_saving_rnn")
def static_state_saving_rnn(cell,
inputs,
state_saver,
@@ -1410,6 +1410,7 @@ def static_state_saving_rnn(cell,
return (outputs, state)
+@tf_export("nn.static_bidirectional_rnn")
def static_bidirectional_rnn(cell_fw,
cell_bw,
inputs,
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index b41aff76d4..f1ac3e9baf 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -47,6 +47,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
_BIAS_VARIABLE_NAME = "bias"
@@ -133,6 +134,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
return nest.map_structure(get_state_shape, state_size)
+@tf_export("nn.rnn_cell.RNNCell")
class RNNCell(base_layer.Layer):
"""Abstract object representing an RNN cell.
@@ -294,6 +296,7 @@ class _LayerRNNCell(RNNCell):
*args, **kwargs)
+@tf_export("nn.rnn_cell.BasicRNNCell")
class BasicRNNCell(_LayerRNNCell):
"""The most basic RNN cell.
@@ -351,6 +354,7 @@ class BasicRNNCell(_LayerRNNCell):
return output, output
+@tf_export("nn.rnn_cell.GRUCell")
class GRUCell(_LayerRNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
@@ -448,6 +452,7 @@ class GRUCell(_LayerRNNCell):
_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
+@tf_export("nn.rnn_cell.LSTMStateTuple")
class LSTMStateTuple(_LSTMStateTuple):
"""Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
@@ -467,6 +472,7 @@ class LSTMStateTuple(_LSTMStateTuple):
return c.dtype
+@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(_LayerRNNCell):
"""Basic LSTM recurrent network cell.
@@ -591,6 +597,7 @@ class BasicLSTMCell(_LayerRNNCell):
return new_h, new_state
+@tf_export("nn.rnn_cell.LSTMCell")
class LSTMCell(_LayerRNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
@@ -834,6 +841,7 @@ def _default_dropout_state_filter_visitor(substate):
return True
+@tf_export("nn.rnn_cell.DropoutWrapper")
class DropoutWrapper(RNNCell):
"""Operator adding dropout to inputs and outputs of the given cell."""
@@ -980,6 +988,10 @@ class DropoutWrapper(RNNCell):
return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
@property
+ def wrapped_cell(self):
+ return self._cell
+
+ @property
def state_size(self):
return self._cell.state_size
@@ -1058,6 +1070,7 @@ class DropoutWrapper(RNNCell):
return output, new_state
+@tf_export("nn.rnn_cell.ResidualWrapper")
class ResidualWrapper(RNNCell):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
@@ -1113,6 +1126,7 @@ class ResidualWrapper(RNNCell):
return (res_outputs, new_state)
+@tf_export("nn.rnn_cell.DeviceWrapper")
class DeviceWrapper(RNNCell):
"""Operator that ensures an RNNCell runs on a particular device."""
@@ -1147,6 +1161,7 @@ class DeviceWrapper(RNNCell):
return self._cell(inputs, state, scope=scope)
+@tf_export("nn.rnn_cell.MultiRNNCell")
class MultiRNNCell(RNNCell):
"""RNN cell composed sequentially of multiple simple cells."""
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index c0c1ade495..4b5072fd67 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -33,6 +33,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_script_ops
+from tensorflow.python.util.tf_export import tf_export
class EagerFunc(object):
@@ -243,6 +244,7 @@ def eager_py_func(func, inp, Tout, name=None):
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
+@tf_export("py_func")
def py_func(func, inp, Tout, stateful=True, name=None):
"""Wraps a python function and uses it as a TensorFlow op.
diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py
index dc4d913c93..cedd36c1de 100644
--- a/tensorflow/python/ops/session_ops.py
+++ b/tensorflow/python/ops/session_ops.py
@@ -36,6 +36,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
def encode_resource_handle(resource_handle):
@@ -141,6 +142,7 @@ class TensorHandle(object):
return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)
+@tf_export("get_session_handle")
def get_session_handle(data, name=None):
"""Return the handle of `data`.
@@ -183,6 +185,7 @@ def get_session_handle(data, name=None):
return gen_data_flow_ops._get_session_handle(data, name=name) # pylint: disable=protected-access
+@tf_export("get_session_tensor")
def get_session_tensor(handle, dtype, name=None):
"""Get the tensor of type `dtype` by feeding a tensor handle.
@@ -223,6 +226,7 @@ def get_session_tensor(handle, dtype, name=None):
return (holder, tensor)
+@tf_export("delete_session_tensor")
def delete_session_tensor(handle, name=None):
"""Delete the tensor for the given tensor handle.
diff --git a/tensorflow/python/ops/sets_impl.py b/tensorflow/python/ops/sets_impl.py
index 6aa9e3419e..b0eecd8a1e 100644
--- a/tensorflow/python/ops/sets_impl.py
+++ b/tensorflow/python/ops/sets_impl.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_set_ops
+from tensorflow.python.util.tf_export import tf_export
_VALID_DTYPES = set([
@@ -30,6 +31,7 @@ _VALID_DTYPES = set([
dtypes.uint8, dtypes.uint16, dtypes.string])
+@tf_export("sets.set_size")
def set_size(a, validate_indices=True):
"""Compute number of unique elements along last dimension of `a`.
@@ -131,6 +133,7 @@ def _set_operation(a, b, set_operation, validate_indices=True):
return sparse_tensor.SparseTensor(indices, values, shape)
+@tf_export("sets.set_intersection")
def set_intersection(a, b, validate_indices=True):
"""Compute set intersection of elements in last dimension of `a` and `b`.
@@ -197,6 +200,7 @@ def set_intersection(a, b, validate_indices=True):
return _set_operation(a, b, "intersection", validate_indices)
+@tf_export("sets.set_difference")
def set_difference(a, b, aminusb=True, validate_indices=True):
"""Compute set difference of elements in last dimension of `a` and `b`.
@@ -267,6 +271,7 @@ def set_difference(a, b, aminusb=True, validate_indices=True):
return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)
+@tf_export("sets.set_union")
def set_union(a, b, validate_indices=True):
"""Compute set union of elements in last dimension of `a` and `b`.
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index c368d166f5..3224856d7b 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -65,6 +65,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.gen_sparse_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
def _convert_to_sparse_tensor(sp_input):
@@ -108,6 +109,7 @@ def _convert_to_sparse_tensors(sp_inputs):
# pylint: disable=protected-access
+@tf_export("sparse_concat")
def sparse_concat(axis,
sp_inputs,
name=None,
@@ -236,6 +238,7 @@ def sparse_concat(axis,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
+@tf_export("sparse_add")
def sparse_add(a, b, thresh=0):
"""Adds two tensors, at least one of each is a `SparseTensor`.
@@ -463,6 +466,7 @@ def sparse_dense_cwise_add(sp_t, dense_t):
return sparse_tensor.SparseTensor(sp_t.indices, result, sp_t.dense_shape)
+@tf_export("sparse_reorder")
def sparse_reorder(sp_input, name=None):
"""Reorders a `SparseTensor` into the canonical, row-major ordering.
@@ -511,6 +515,7 @@ def sparse_reorder(sp_input, name=None):
return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape)
+@tf_export("sparse_reshape")
def sparse_reshape(sp_input, shape, name=None):
"""Reshapes a `SparseTensor` to represent values in a new dense shape.
@@ -603,6 +608,7 @@ class KeywordRequired(object):
return "KeywordRequired()"
+@tf_export("sparse_split")
def sparse_split(keyword_required=KeywordRequired(),
sp_input=None, num_split=None, axis=None,
name=None, split_dim=None):
@@ -669,6 +675,7 @@ def sparse_split(keyword_required=KeywordRequired(),
return sparse_tensors
+@tf_export("sparse_slice")
def sparse_slice(sp_input, start, size, name=None):
"""Slice a `SparseTensor` based on the `start` and `size.
@@ -713,6 +720,8 @@ def sparse_slice(sp_input, start, size, name=None):
output_values,
output_shape)
+
+@tf_export("sparse_to_dense")
def sparse_to_dense(sparse_indices,
output_shape,
sparse_values,
@@ -768,6 +777,7 @@ def sparse_to_dense(sparse_indices,
name=name)
+@tf_export("sparse_reduce_max")
def sparse_reduce_max(sp_input, axis=None, keep_dims=False,
reduction_axes=None):
"""Computes the max of elements across dimensions of a SparseTensor.
@@ -815,6 +825,7 @@ def sparse_reduce_max(sp_input, axis=None, keep_dims=False,
keep_dims)
+@tf_export("sparse_reduce_max_sparse")
def sparse_reduce_max_sparse(sp_input, axis=None, keep_dims=False,
reduction_axes=None):
"""Computes the max of elements across dimensions of a SparseTensor.
@@ -852,6 +863,7 @@ def sparse_reduce_max_sparse(sp_input, axis=None, keep_dims=False,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
+@tf_export("sparse_reduce_sum")
def sparse_reduce_sum(sp_input, axis=None, keep_dims=False,
reduction_axes=None):
"""Computes the sum of elements across dimensions of a SparseTensor.
@@ -899,6 +911,7 @@ def sparse_reduce_sum(sp_input, axis=None, keep_dims=False,
keep_dims)
+@tf_export("sparse_reduce_sum_sparse")
def sparse_reduce_sum_sparse(sp_input, axis=None, keep_dims=False,
reduction_axes=None):
"""Computes the sum of elements across dimensions of a SparseTensor.
@@ -936,6 +949,7 @@ def sparse_reduce_sum_sparse(sp_input, axis=None, keep_dims=False,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
+@tf_export("sparse_tensor_to_dense")
def sparse_tensor_to_dense(sp_input,
default_value=0,
validate_indices=True,
@@ -987,6 +1001,7 @@ def sparse_tensor_to_dense(sp_input,
name=name)
+@tf_export("sparse_to_indicator")
def sparse_to_indicator(sp_input, vocab_size, name=None):
"""Converts a `SparseTensor` of ids into a dense bool indicator tensor.
@@ -1049,6 +1064,7 @@ def sparse_to_indicator(sp_input, vocab_size, name=None):
sp_new, default_value=False, validate_indices=False, name=name)
+@tf_export("sparse_merge")
def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
already_sorted=False):
"""Combines a batch of feature ids and values into a single `SparseTensor`.
@@ -1189,6 +1205,7 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
return result if already_sorted else sparse_reorder(result)
+@tf_export("sparse_retain")
def sparse_retain(sp_input, to_retain):
"""Retains specified non-empty values within a `SparseTensor`.
@@ -1232,6 +1249,7 @@ def sparse_retain(sp_input, to_retain):
array_ops.identity(sp_input.dense_shape))
+@tf_export("sparse_reset_shape")
def sparse_reset_shape(sp_input, new_shape=None):
"""Resets the shape of a `SparseTensor` with indices and values unchanged.
@@ -1333,6 +1351,7 @@ def sparse_reset_shape(sp_input, new_shape=None):
return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)
+@tf_export("sparse_fill_empty_rows")
def sparse_fill_empty_rows(sp_input, default_value, name=None):
"""Fills empty rows in the input 2-D `SparseTensor` with a default value.
@@ -1396,6 +1415,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
empty_row_indicator)
+@tf_export("serialize_sparse")
def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
@@ -1421,6 +1441,7 @@ def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
out_type=out_type)
+@tf_export("serialize_many_sparse")
def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
@@ -1521,6 +1542,7 @@ def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):
return sparse_tensor.SparseTensor(output_indices, output_values, output_shape)
+@tf_export("deserialize_many_sparse")
def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
"""Deserialize and concatenate `SparseTensors` from a serialized minibatch.
@@ -1590,6 +1612,7 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
return sparse_tensor.SparseTensor(output_indices, output_values, output_shape)
+@tf_export("sparse_tensor_dense_matmul")
def sparse_tensor_dense_matmul(sp_a,
b,
adjoint_a=False,
@@ -1806,6 +1829,7 @@ def sparse_tensor_dense_matmul(sp_a,
adjoint_b=adjoint_b)
+@tf_export("sparse_softmax")
def sparse_softmax(sp_input, name=None):
"""Applies softmax to a batched N-D `SparseTensor`.
@@ -1860,6 +1884,7 @@ def sparse_softmax(sp_input, name=None):
sp_input.indices, out_vals, sp_input.dense_shape)
+@tf_export("sparse_maximum")
def sparse_maximum(sp_a, sp_b, name=None):
"""Returns the element-wise max of two SparseTensors.
@@ -1896,6 +1921,7 @@ def sparse_maximum(sp_a, sp_b, name=None):
return sparse_tensor.SparseTensor(out_indices, out_values, sp_a.dense_shape)
+@tf_export("sparse_minimum")
def sparse_minimum(sp_a, sp_b, name=None):
"""Returns the element-wise min of two SparseTensors.
@@ -1932,6 +1958,7 @@ def sparse_minimum(sp_a, sp_b, name=None):
return sparse_tensor.SparseTensor(out_indices, out_values, sp_a.dense_shape)
+@tf_export("sparse_transpose")
def sparse_transpose(sp_input, perm=None, name=None):
"""Transposes a `SparseTensor`
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index fe3f734322..15127862a4 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -31,9 +31,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util.tf_export import tf_export
# TODO(b/27419586) Change docstring for required dtype of x once int allowed
+@tf_export('lbeta')
def lbeta(x, name='lbeta'):
r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.
@@ -82,6 +84,7 @@ def lbeta(x, name='lbeta'):
return result
+@tf_export('einsum', 'linalg.einsum')
def einsum(equation, *inputs, **kwargs):
"""A generalized contraction between tensors of arbitrary dimension.
@@ -152,27 +155,24 @@ def einsum(equation, *inputs, **kwargs):
indices in its subscript, or
- the input shapes are inconsistent along a particular axis.
"""
- name = kwargs.pop("name", None)
+ name = kwargs.pop('name', None)
if kwargs:
- raise TypeError("invalid keyword arguments for this function: " +
- ", ".join([format(key)
- for key in sorted(list(kwargs.keys()))]))
- with ops.name_scope(name, "einsum", [equation, inputs]) as name:
+ raise TypeError('invalid keyword arguments for this function: ' + ', '.join(
+ [format(key) for key in sorted(list(kwargs.keys()))]))
+ with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
if '...' in equation:
raise ValueError('Subscripts with ellipses are not yet supported.')
match = re.match('([a-z,]+)(->[a-z]*)?', equation)
if not match:
- raise ValueError(
- 'Indices have incorrect format: %s' % equation
- )
+ raise ValueError('Indices have incorrect format: %s' % equation)
inputs = list(inputs)
input_axis_labels = match.group(1).split(',')
if len(inputs) != len(input_axis_labels):
- raise ValueError('Got %d arguments for equation "%s", expecting %d' % (
- len(inputs), equation, len(input_axis_labels)))
+ raise ValueError('Got %d arguments for equation "%s", expecting %d' %
+ (len(inputs), equation, len(input_axis_labels)))
axis_labels = set(''.join(input_axis_labels))
if match.group(2):
@@ -185,10 +185,8 @@ def einsum(equation, *inputs, **kwargs):
for ax in axes_:
counts[ax] += 1
- output_axis_labels = ''.join(sorted(
- ax for ax in indices
- if counts[ax] == 1
- ))
+ output_axis_labels = ''.join(
+ sorted(ax for ax in indices if counts[ax] == 1))
for a in axis_labels:
input_count = sum(1 for s in input_axis_labels if a in s)
@@ -200,22 +198,23 @@ def einsum(equation, *inputs, **kwargs):
temp = inputs[0]
temp_axis_labels = input_axis_labels[0]
- for i in xrange(len(inputs)-1):
- axes_to_sum = (set(temp_axis_labels) & set(input_axis_labels[i+1])
- - set(output_axis_labels))
- temp, temp_axis_labels = _einsum_reduction(temp,
- temp_axis_labels,
- inputs[i+1],
- input_axis_labels[i+1],
- axes_to_sum)
+ for i in xrange(len(inputs) - 1):
+ axes_to_sum = (
+ set(temp_axis_labels) &
+ set(input_axis_labels[i + 1]) - set(output_axis_labels))
+ temp, temp_axis_labels = _einsum_reduction(
+ temp, temp_axis_labels, inputs[i + 1], input_axis_labels[i + 1],
+ axes_to_sum)
missing_indices = set(temp_axis_labels) - set(output_axis_labels)
if missing_indices:
- reduction_indices = [i for i, a in enumerate(temp_axis_labels)
- if a not in output_axis_labels]
+ reduction_indices = [
+ i for i, a in enumerate(temp_axis_labels)
+ if a not in output_axis_labels
+ ]
temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices)
- temp_axis_labels = ''.join(a for a in temp_axis_labels
- if a in output_axis_labels)
+ temp_axis_labels = ''.join(
+ a for a in temp_axis_labels if a in output_axis_labels)
if sorted(temp_axis_labels) != sorted(output_axis_labels):
raise ValueError('Invalid equation: %s' % equation)
@@ -293,8 +292,10 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
return (1, a)
axis_labels = [t0_axis_labels, t1_axis_labels]
- sorted_axes = [sorted(sym_list, key=lambda a: sort_key(i, a))
- for i, sym_list in enumerate(axis_labels)]
+ sorted_axes = [
+ sorted(sym_list, key=lambda a: sort_key(i, a))
+ for i, sym_list in enumerate(axis_labels)
+ ]
inputs = [t0, t1]
for i, axes_str in enumerate(axis_labels):
perm = [axes_str.find(a) for a in sorted_axes[i]]
@@ -322,30 +323,30 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
num_broadcast_elements_t0 = _total_size(
t0_shape[len(preserved_axes):-len(axes_to_sum)])
num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
- new_shape = (t0_shape[:len(preserved_axes)]
- + [num_broadcast_elements_t0, num_summed_elements])
+ new_shape = (
+ t0_shape[:len(preserved_axes)] +
+ [num_broadcast_elements_t0, num_summed_elements])
t0 = _reshape_if_necessary(t0, new_shape)
t1_shape = _get_shape(t1)
num_broadcast_elements_t1 = _total_size(
- t1_shape[len(preserved_axes)+len(axes_to_sum):])
- new_shape = (t1_shape[:len(preserved_axes)]
- + [num_summed_elements, num_broadcast_elements_t1])
+ t1_shape[len(preserved_axes) + len(axes_to_sum):])
+ new_shape = (
+ t1_shape[:len(preserved_axes)] +
+ [num_summed_elements, num_broadcast_elements_t1])
t1 = _reshape_if_necessary(t1, new_shape)
product = math_ops.matmul(t0, t1)
# Undo compaction of broadcast axes
uncompacted_shape = (
- t0_shape[:len(preserved_axes)+len(broadcast_axes[0])]
- + t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
- )
+ t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] +
+ t1_shape[len(t1_shape) - len(broadcast_axes[1]):])
product = _reshape_if_necessary(product, uncompacted_shape)
product_axes = (
- sorted_axes[0][:len(preserved_axes)+len(broadcast_axes[0])] +
- sorted_axes[1][len(sorted_axes[1])-len(broadcast_axes[1]):]
- )
+ sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] +
+ sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):])
return product, ''.join(product_axes)
@@ -399,13 +400,11 @@ def _total_size(shape_values):
def _exponential_space_einsum(equation, *inputs):
"""Fallback implementation that supports summing an index over > 2 inputs."""
if '...' in equation:
- raise ValueError("Subscripts with ellipses are not yet supported.")
+ raise ValueError('Subscripts with ellipses are not yet supported.')
match = re.match('([a-z,]+)(->[a-z]*)?', equation)
if not match:
- raise ValueError(
- 'Indices have incorrect format: %s' % equation
- )
+ raise ValueError('Indices have incorrect format: %s' % equation)
inputs = list(inputs)
idx_in = match.group(1).split(',')
@@ -422,21 +421,15 @@ def _exponential_space_einsum(equation, *inputs):
for ax in axes_:
counts[ax] += 1
- idx_out = ''.join(sorted(
- ax for ax in indices
- if counts[ax] == 1
- ))
+ idx_out = ''.join(sorted(ax for ax in indices if counts[ax] == 1))
if len(idx_in) != len(inputs):
- raise ValueError(
- 'Expected %d inputs but got %d' % (len(idx_in), len(inputs))
- )
+ raise ValueError('Expected %d inputs but got %d' % (len(idx_in),
+ len(inputs)))
missing_idx = set(idx_out).difference(idx_all)
if missing_idx:
- raise ValueError(
- 'Unknown output axes: %s' % missing_idx
- )
+ raise ValueError('Unknown output axes: %s' % missing_idx)
axis_order = {}
for ax in indices:
@@ -449,18 +442,17 @@ def _exponential_space_einsum(equation, *inputs):
for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
if input_.get_shape().ndims != len(axes_):
raise ValueError(
- 'Input %d with axes %s has incorrect' \
- ' number of dimensions (expected %d, got %d)' % (
- i, axes_, len(axes_), input_.get_shape().ndims
- )
+ 'Input %d with axes %s has incorrect' \
+ ' number of dimensions (expected %d, got %d)' % (
+ i, axes_, len(axes_), input_.get_shape().ndims
+ )
)
sorted_idx = sorted(axes_, key=axis_order.get)
if len(set(axes_)) != len(axes_):
raise ValueError(
- 'Subscript not supported: an axis appears more than once: %s' % axes_
- )
+ 'Subscript not supported: an axis appears more than once: %s' % axes_)
if list(axes_) != sorted_idx:
permuted = [axes_.find(ax) for ax in sorted_idx]
@@ -484,16 +476,15 @@ def _exponential_space_einsum(equation, *inputs):
dims.append(dim)
if len(set(dims)) > 1:
- raise ValueError(
- 'Dimension mismatch on axis: %s' % ax
- )
+ raise ValueError('Dimension mismatch on axis: %s' % ax)
if ax not in idx_out:
reduction_idx.append(j)
# reshape, multiply
- expanded_inputs = [array_ops.reshape(input_, shape)
- for input_, shape in zip(inputs, shapes)]
+ expanded_inputs = [
+ array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes)
+ ]
expanded_output = 1
for input_ in expanded_inputs:
expanded_output *= input_
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index c1a66717d8..2c212f4548 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -39,8 +39,9 @@ class LBetaTest(test.TestCase):
x_one_half = [2, 1.]
with self.test_session(use_gpu=True):
self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_one)).eval())
- self.assertAllClose(
- 0.5, math_ops.exp(special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose(0.5,
+ math_ops.exp(
+ special_math_ops.lbeta(x_one_half)).eval())
self.assertEqual([], special_math_ops.lbeta(x_one).get_shape())
def test_one_dimensional_arg_dynamic(self):
@@ -70,8 +71,9 @@ class LBetaTest(test.TestCase):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
with self.test_session(use_gpu=True):
- self.assertAllClose(
- [0.5, 0.5], math_ops.exp(special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose([0.5, 0.5],
+ math_ops.exp(
+ special_math_ops.lbeta(x_one_half)).eval())
self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape())
def test_two_dimensional_arg_dynamic(self):
@@ -86,10 +88,12 @@ class LBetaTest(test.TestCase):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
with self.test_session(use_gpu=True):
- self.assertAllClose(
- [0.5, 0.5], math_ops.exp(special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose([0.5, 0.5],
+ math_ops.exp(
+ special_math_ops.lbeta(x_one_half)).eval())
self.assertEqual(
- (2,), array_ops.shape(special_math_ops.lbeta(x_one_half)).eval())
+ (2,),
+ array_ops.shape(special_math_ops.lbeta(x_one_half)).eval())
self.assertEqual(
tensor_shape.TensorShape([2]),
special_math_ops.lbeta(x_one_half).get_shape())
@@ -97,8 +101,8 @@ class LBetaTest(test.TestCase):
def test_complicated_shape(self):
with self.test_session(use_gpu=True):
x = ops.convert_to_tensor(np.random.rand(3, 2, 2))
- self.assertAllEqual(
- (3, 2), array_ops.shape(special_math_ops.lbeta(x)).eval())
+ self.assertAllEqual((3, 2),
+ array_ops.shape(special_math_ops.lbeta(x)).eval())
self.assertEqual(
tensor_shape.TensorShape([3, 2]),
special_math_ops.lbeta(x).get_shape())
@@ -155,7 +159,6 @@ class EinsumTest(test.TestCase):
'ijk->i',
'ijk->kji',
'ji,kj->ik',
-
'ikl,kji->kl',
'klj,lki->ij',
'ijk,ilj->kli',
@@ -164,7 +167,6 @@ class EinsumTest(test.TestCase):
'i,ijk,j->k',
'ij,ij,jk,kl->il',
'ij,kj,il,jm->ml',
-
'a,ab,abc->abc',
'a,b,ab->ab',
'ab,ab,c->',
@@ -173,25 +175,21 @@ class EinsumTest(test.TestCase):
'ab,ab,cd,cd->ac',
'ab,ab,cd,cd->cd',
'ab,ab,cd,cd,ef,ef->',
-
'ab,cd,ef->abcdef',
'ab,cd,ef->acdf',
'ab,cd,de->abcde',
'ab,cd,de->be',
'ab,bcd,cd->abcd',
'ab,bcd,cd->abd',
-
'eb,cb,fb->cef',
'abcd,ad',
'bd,db,eac->ace',
'ba,ac,da->bcd',
-
'ab,ab',
'ab,ba',
'abc,abc',
'abc,bac',
'abc,cba',
-
'dba,ead,cad->bce',
'aef,fbc,dca->bde',
]
@@ -234,10 +232,8 @@ class EinsumTest(test.TestCase):
def test_invalid(self):
for axes in self.invalid_cases:
inputs = [
- array_ops.placeholder(
- dtypes.float32, shape=(3, 4)),
- array_ops.placeholder(
- dtypes.float32, shape=(3, 4)),
+ array_ops.placeholder(dtypes.float32, shape=(3, 4)),
+ array_ops.placeholder(dtypes.float32, shape=(3, 4)),
]
with self.assertRaises(ValueError):
_ = special_math_ops.einsum(axes, *inputs)
@@ -245,16 +241,22 @@ class EinsumTest(test.TestCase):
def test_invalid_keyword_arguments(self):
m0 = array_ops.placeholder(dtypes.int32, shape=(1, None))
m1 = array_ops.placeholder(dtypes.int32, shape=(None, 1))
- with self.assertRaisesRegexp(TypeError,
+ with self.assertRaisesRegexp(
+ TypeError,
'invalid keyword arguments for this function: invalid1, invalid2'):
- _ = special_math_ops.einsum('ij,jk->ik', m0, m1, name="name",
- invalid1="value1", invalid2="value2")
+ _ = special_math_ops.einsum(
+ 'ij,jk->ik',
+ m0,
+ m1,
+ name='name',
+ invalid1='value1',
+ invalid2='value2')
def test_dim_mismatch(self):
for axes, input_shapes in self.dim_mismatch_cases:
inputs = [
- array_ops.placeholder(
- dtypes.float32, shape=shape) for shape in input_shapes
+ array_ops.placeholder(dtypes.float32, shape=shape)
+ for shape in input_shapes
]
with self.assertRaises(ValueError):
_ = special_math_ops.einsum(axes, *inputs)
@@ -291,8 +293,8 @@ class EinsumTest(test.TestCase):
m0: [[1, 2, 3]],
m1: [[2], [1], [1]],
}
- np.testing.assert_almost_equal(
- [[7]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[7]], sess.run(
+ out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 3))
@@ -312,11 +314,11 @@ class EinsumTest(test.TestCase):
out = special_math_ops.einsum('ijk,kl->ijl', m0, m1)
with session.Session() as sess:
feed_dict = {
- m0: [[[1,2]]],
+ m0: [[[1, 2]]],
m1: [[3], [2]],
}
- np.testing.assert_almost_equal(
- [[[7]]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[[7]]],
+ sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1))
@@ -325,10 +327,10 @@ class EinsumTest(test.TestCase):
with session.Session() as sess:
feed_dict = {
m0: [[3], [2]],
- m1: [[[1,2]]],
+ m1: [[[1, 2]]],
}
- np.testing.assert_almost_equal(
- [[[7]]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[[7]]],
+ sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
@@ -339,8 +341,8 @@ class EinsumTest(test.TestCase):
m0: [[[1, 2]]],
m1: [3, 2],
}
- np.testing.assert_almost_equal(
- [[7]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[7]], sess.run(
+ out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2))
@@ -351,8 +353,8 @@ class EinsumTest(test.TestCase):
m0: [[[[1, 2]], [[2, 1]]]],
m1: [[3, 2]],
}
- np.testing.assert_almost_equal(
- [[[7, 8]]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[[7, 8]]],
+ sess.run(out, feed_dict=feed_dict))
if __name__ == '__main__':
diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py
index 69f868c67a..a579688276 100644
--- a/tensorflow/python/ops/spectral_ops.py
+++ b/tensorflow/python/ops/spectral_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.tf_export import tf_export
def _infer_fft_length_for_rfft(input_tensor, fft_rank):
@@ -164,11 +165,17 @@ ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
+tf_export("spectral.rfft")(rfft)
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
+tf_export("spectral.irfft")(irfft)
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
+tf_export("spectral.rfft2d")(rfft2d)
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
+tf_export("spectral.irfft2d")(irfft2d)
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
+tf_export("spectral.rfft3d")(rfft3d)
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
+tf_export("spectral.irfft3d")(irfft3d)
def _validate_dct_arguments(dct_type, n, axis, norm):
@@ -184,6 +191,7 @@ def _validate_dct_arguments(dct_type, n, axis, norm):
# TODO(rjryan): Implement `type`, `n` and `axis` parameters.
+@tf_export("spectral.dct")
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index dee495f78f..3cc76fdbf3 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -89,6 +89,7 @@ from tensorflow.python.ops import gen_state_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_state_ops import *
+from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -189,6 +190,7 @@ def is_variable_initialized(ref, name=None):
name=name)
+@tf_export("assign_sub")
def assign_sub(ref, value, use_locking=None, name=None):
"""Update 'ref' by subtracting 'value' from it.
@@ -217,6 +219,7 @@ def assign_sub(ref, value, use_locking=None, name=None):
return ref.assign_sub(value)
+@tf_export("assign_add")
def assign_add(ref, value, use_locking=None, name=None):
"""Update 'ref' by adding 'value' to it.
@@ -245,6 +248,7 @@ def assign_add(ref, value, use_locking=None, name=None):
return ref.assign_add(value)
+@tf_export("assign")
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
"""Update 'ref' by assigning 'value' to it.
@@ -277,6 +281,7 @@ def assign(ref, value, validate_shape=None, use_locking=None, name=None):
return ref.assign(value)
+@tf_export("count_up_to")
def count_up_to(ref, limit, name=None):
r"""Increments 'ref' until it reaches 'limit'.
@@ -299,6 +304,7 @@ def count_up_to(ref, limit, name=None):
ref.handle, limit, T=ref.dtype, name=name)
+@tf_export("scatter_update")
def scatter_update(ref, indices, updates, use_locking=True, name=None):
# pylint: disable=line-too-long
r"""Applies sparse updates to a variable reference.
@@ -354,6 +360,7 @@ def scatter_update(ref, indices, updates, use_locking=True, name=None):
return ref.read_value()
+@tf_export("scatter_nd_update")
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
r"""Applies sparse `updates` to individual values or slices in a Variable.
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index f30e79a108..b8c39d91b4 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -47,9 +47,11 @@ from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_string_ops import *
from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
@@ -120,6 +122,7 @@ def _reduce_join_reduction_dims(x, axis, reduction_indices):
return math_ops.range(array_ops.rank(x) - 1, -1, -1)
+@tf_export("reduce_join")
def reduce_join(inputs, axis=None,
keep_dims=False,
separator="",
diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py
index 2cf2eda16e..7f4f4ce5ab 100644
--- a/tensorflow/python/ops/summary_ops.py
+++ b/tensorflow/python/ops/summary_ops.py
@@ -25,9 +25,11 @@ from tensorflow.python.ops import summary_op_util
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_logging_ops import *
+from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+@tf_export("summary.tensor_summary")
def tensor_summary(name,
tensor,
summary_description=None,
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 99a71cbe79..84449e00be 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -29,11 +29,13 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
__all__ = ["make_template"]
+@tf_export("make_template")
def make_template(name_, func_, create_scope_now_=False, unique_name_=None,
custom_getter_=None, **kwargs):
"""Given an arbitrary function, wrap it so that it does variable sharing.
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index 398521c9b5..5cdf03509e 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_should_use
+from tensorflow.python.util.tf_export import tf_export
# _GraphTensorArray accesses many of the hidden generated ops, but is in
@@ -711,6 +712,7 @@ class _EagerTensorArray(object):
# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
+@tf_export("TensorArray")
class TensorArray(object):
"""Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 3a39af8e20..81565a6377 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -27,6 +27,7 @@ import sys
import traceback
import six
+from six import iteritems
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
@@ -40,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util.tf_export import tf_export
__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope",
"get_variable", "get_local_variable", "variable_scope",
@@ -186,6 +188,7 @@ class _ReuseMode(enum.Enum):
# REUSE_TRUE = 3
AUTO_REUSE = _ReuseMode.AUTO_REUSE
+tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
get_variable() should create the requested variable if it doesn't exist or, if
@@ -768,8 +771,8 @@ class _VariableStore(object):
if initializer is None:
initializer, initializing_from_value = self._get_default_initializer(
name=name, shape=shape, dtype=dtype)
- # Clear control dependencies while creating the initializer.
- with ops.control_dependencies(None):
+ # Enter an init scope when creating the initializer.
+ with ops.init_scope():
if initializing_from_value:
init_val = initializer
variable_dtype = None
@@ -785,26 +788,16 @@ class _VariableStore(object):
if use_resource is None:
# Set the default value if unspecified.
use_resource = False
- if use_resource:
- v = resource_variable_ops.ResourceVariable(
- initial_value=init_val,
- name=name,
- trainable=trainable,
- collections=collections,
- caching_device=caching_device,
- dtype=variable_dtype,
- validate_shape=validate_shape,
- constraint=constraint)
- else:
- v = variables.Variable(
- initial_value=init_val,
- name=name,
- trainable=trainable,
- collections=collections,
- caching_device=caching_device,
- dtype=variable_dtype,
- validate_shape=validate_shape,
- constraint=constraint)
+ v = variable(
+ initial_value=init_val,
+ name=name,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ dtype=variable_dtype,
+ validate_shape=validate_shape,
+ constraint=constraint,
+ use_resource=use_resource)
if context.in_graph_mode() or self._store_eager_variables:
# In eager mode we do not want to keep default references to Variable
# objects as this will prevent their memory from being released.
@@ -863,12 +856,14 @@ class _VariableStore(object):
# To stop regularization, use this regularizer
+@tf_export("no_regularizer")
def no_regularizer(_):
"""Use this function to prevent regularization of variables."""
return None
# TODO(alive): support caching devices and partitioned variables in Eager mode.
+@tf_export("VariableScope")
class VariableScope(object):
"""Variable scope object to carry defaults to provide to `get_variable`.
@@ -1168,6 +1163,7 @@ _VARSTORE_KEY = ("__variable_store",)
_VARSCOPE_KEY = ("__varscope",)
+@tf_export("get_variable_scope")
def get_variable_scope():
"""Returns the current variable scope."""
scope = ops.get_collection(_VARSCOPE_KEY)
@@ -1247,7 +1243,38 @@ class EagerVariableStore(object):
key=lambda x: x.name)
# pylint: enable=protected-access
+ def copy(self):
+ """Copy this variable store and all of its contents.
+
+ Variables contained in this store will be copied over to the new variable
+ store, meaning that they can be modified without affecting the variables in
+ this store.
+
+ Returns:
+ A new EagerVariableStore instance containing copied variables.
+ """
+ # pylint: disable=protected-access
+ new_store = EagerVariableStore()
+ for key, var in iteritems(self._store._vars):
+ # Strip device out of variable name.
+ try:
+ index = var.name.index(":")
+ except ValueError:
+ stripped_var_name = var.name
+ else:
+ stripped_var_name = var.name[:index]
+
+ # Create new variable with same value, name, and "trainable" flag.
+ new_var = resource_variable_ops.ResourceVariable(
+ var.read_value(),
+ name=stripped_var_name,
+ trainable=var._trainable)
+ new_store._store._vars[key] = new_var
+ return new_store
+ # pylint: enable=protected-access
+
+@tf_export("get_variable")
def get_variable(name,
shape=None,
dtype=None,
@@ -1359,6 +1386,7 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
@functools.wraps(get_variable)
+@tf_export("get_local_variable")
def get_local_variable(*args, **kwargs):
kwargs["trainable"] = False
if "collections" in kwargs:
@@ -1673,7 +1701,8 @@ def _get_unique_variable_scope(prefix):
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
-class variable_scope(object): # pylint: disable=invalid-name
+@tf_export("variable_scope") # pylint: disable=invalid-name
+class variable_scope(object):
"""A context manager for defining ops that creates variables (layers).
This context manager validates that the (optional) `values` are from the same
@@ -2006,6 +2035,7 @@ class variable_scope(object): # pylint: disable=invalid-name
# pylint: disable=g-doc-return-or-yield
+@tf_export("variable_op_scope")
@tf_contextlib.contextmanager
def variable_op_scope(values,
name_or_scope,
@@ -2067,21 +2097,26 @@ def _compute_slice_dim_and_shape(full_shape, slicing):
return slice_dim, slice_shape
-def variable(initial_value=None,
- trainable=True,
- collections=None,
- validate_shape=True,
- caching_device=None,
- name=None,
- dtype=None,
- use_resource=None):
+def default_variable_creator(next_creator=None, **kwargs):
+ """Default variable creator."""
+ assert next_creator is None
+ initial_value = kwargs.get("initial_value", None)
+ trainable = kwargs.get("trainable", True)
+ collections = kwargs.get("collections", None)
+ validate_shape = kwargs.get("validate_shape", True)
+ caching_device = kwargs.get("caching_device", None)
+ name = kwargs.get("name", None)
+ dtype = kwargs.get("dtype", None)
+ constraint = kwargs.get("constraint", None)
+ use_resource = kwargs.get("use_resource", None)
if use_resource is None:
use_resource = get_variable_scope().use_resource
if use_resource or (use_resource is None and context.in_eager_mode()):
return resource_variable_ops.ResourceVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
- caching_device=caching_device, name=name, dtype=dtype)
+ caching_device=caching_device, name=name, dtype=dtype,
+ constraint=constraint)
elif not use_resource and context.in_eager_mode():
raise RuntimeError(
"VariableScope should use resource variable when eager execution is"
@@ -2091,4 +2126,95 @@ def variable(initial_value=None,
return variables.Variable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
- caching_device=caching_device, name=name, dtype=dtype)
+ caching_device=caching_device, name=name, dtype=dtype,
+ constraint=constraint)
+
+
+def _make_getter(captured_getter, captured_previous):
+ """Gets around capturing loop variables in python being broken."""
+ return lambda **kwargs: captured_getter(captured_previous, **kwargs)
+
+
+def variable(initial_value=None,
+ trainable=True,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ dtype=None,
+ constraint=None,
+ use_resource=None):
+ previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
+ for getter in ops.get_default_graph()._get_variable_creator_stack(): # pylint: disable=protected-access
+ previous_getter = _make_getter(getter, previous_getter)
+ return previous_getter(initial_value=initial_value,
+ trainable=trainable,
+ collections=collections,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name, dtype=dtype,
+ constraint=constraint,
+ use_resource=use_resource)
+
+
+@tf_contextlib.contextmanager
+def variable_creator_scope(variable_creator):
+ """Scope which defines a variable creation function to be used by variable().
+
+ variable_creator is expected to be a function with the following signature:
+
+ ```
+ def variable_creator(next_creator, **kwargs)
+ ```
+
+ The creator is supposed to eventually call the next_creator to create a
+ variable if it does want to create a variable and not call Variable or
+ ResourceVariable directly. This helps make creators composable. A creator may
+ choose to create multiple variables, return already existing variables, or
+ simply register that a variable was created and defer to the next creators in
+ line. Creators can also modify the keyword arguments seen by the next
+ creators.
+
+ Custom getters in the variable scope will eventually resolve down to these
+ custom creators when they do create variables.
+
+ The valid keyword arguments in kwds are:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ constraint: A constraint function to be applied to the variable after
+ updates by some algorithms.
+ use_resource: if True, a ResourceVariable is always created.
+
+ This set may grow over time, so it's important the signature of creators is as
+ mentioned above.
+
+ Args:
+ variable_creator: the passed creator
+
+ Yields:
+ A scope in which the creator is active
+ """
+ with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
+ yield
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index b25855633e..19e3298e40 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -28,11 +28,14 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("Variable")
class Variable(object):
"""See the @{$variables$Variables How To} for a high level overview.
@@ -209,6 +212,7 @@ class Variable(object):
if not context.in_graph_mode():
raise RuntimeError("tf.Variable not supported in Eager mode. "
"Please use tfe.Variable instead")
+ self._in_graph_mode = context.in_graph_mode()
if variable_def:
# If variable_def is provided, recreates the variable from its fields.
if initial_value:
@@ -304,7 +308,7 @@ class Variable(object):
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
- with ops.control_dependencies(None):
+ with ops.init_scope():
with ops.name_scope(name, "Variable", [] if init_from_fn else
[initial_value]) as name:
@@ -375,8 +379,8 @@ class Variable(object):
else:
with ops.colocate_with(self._variable.op):
self._snapshot = array_ops.identity(self._variable, name="read")
+ ops.add_to_collections(collections, self)
- ops.add_to_collections(collections, self)
self._caching_device = caching_device
self._save_slice_info = None
self._constraint = constraint
@@ -550,7 +554,7 @@ class Variable(object):
A `Tensor` holding the value of this variable after its initializer
has run.
"""
- with ops.control_dependencies(None):
+ with ops.init_scope():
return control_flow_ops.cond(is_variable_initialized(self),
self.read_value,
lambda: self.initial_value)
@@ -1019,6 +1023,61 @@ class Variable(object):
return Variable(variable_def=variable_def,
import_scope=import_scope)
+ def __iadd__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable += will be deprecated. Use variable.assign_add"
+ " if you want assignment to the variable value or 'x = x + y'"
+ " if you want a new python Tensor object.", 1)
+ return self + other
+
+ def __isub__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable -= will be deprecated. Use variable.assign_sub"
+ " if you want assignment to the variable value or 'x = x - y'"
+ " if you want a new python Tensor object.", 1)
+ return self - other
+
+ def __imul__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable *= will be deprecated. Use variable.assign_mul"
+ " if you want assignment to the variable value or 'x = x * y'"
+ " if you want a new python Tensor object.", 1)
+ return self * other
+
+ def __idiv__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable /= will be deprecated. Use variable.assign_div"
+ " if you want assignment to the variable value or 'x = x / y'"
+ " if you want a new python Tensor object.", 1)
+ return self / other
+
+ def __itruediv__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable /= will be deprecated. Use variable.assign_div"
+ " if you want assignment to the variable value or 'x = x / y'"
+ " if you want a new python Tensor object.", 1)
+ return self / other
+
+ def __irealdiv__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable /= will be deprecated. Use variable.assign_div"
+ " if you want assignment to the variable value or 'x = x / y'"
+ " if you want a new python Tensor object.", 1)
+ return self / other
+
+ def __ipow__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable **= will be deprecated. Use 'x = x ** y'"
+ " if you want a new python Tensor object.", 1)
+ return self ** other
+
class SaveSliceInfo(object):
"""Information on how to save this Variable as a slice.
@@ -1308,6 +1367,7 @@ class PartitionedVariable(object):
"assign() has not been implemented for PartitionedVariable.")
+@tf_export("global_variables")
def global_variables(scope=None):
"""Returns global variables.
@@ -1333,6 +1393,7 @@ def global_variables(scope=None):
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
+@tf_export("all_variables")
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
"""See `tf.global_variables`."""
@@ -1357,6 +1418,7 @@ def _all_saveable_objects(scope=None):
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
+@tf_export("local_variables")
def local_variables(scope=None):
"""Returns local variables.
@@ -1384,6 +1446,7 @@ def local_variables(scope=None):
return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
+@tf_export("model_variables")
def model_variables(scope=None):
"""Returns all variables in the MODEL_VARIABLES collection.
@@ -1400,6 +1463,7 @@ def model_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
+@tf_export("trainable_variables")
def trainable_variables(scope=None):
"""Returns all variables created with `trainable=True`.
@@ -1421,6 +1485,7 @@ def trainable_variables(scope=None):
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
+@tf_export("moving_average_variables")
def moving_average_variables(scope=None):
"""Returns all variables that maintain their moving averages.
@@ -1442,6 +1507,7 @@ def moving_average_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)
+@tf_export("initializers.variables", "variables_initializer")
def variables_initializer(var_list, name="init"):
"""Returns an Op that initializes a list of variables.
@@ -1467,6 +1533,7 @@ def variables_initializer(var_list, name="init"):
return control_flow_ops.no_op(name=name)
+@tf_export("initialize_variables")
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
@@ -1474,6 +1541,7 @@ def initialize_variables(var_list, name="init"):
return variables_initializer(var_list, name=name)
+@tf_export("initializers.global_variables", "global_variables_initializer")
def global_variables_initializer():
"""Returns an Op that initializes global variables.
@@ -1487,6 +1555,7 @@ def global_variables_initializer():
return variables_initializer(global_variables())
+@tf_export("initialize_all_variables")
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
@@ -1494,6 +1563,7 @@ def initialize_all_variables():
return global_variables_initializer()
+@tf_export("initializers.local_variables", "local_variables_initializer")
def local_variables_initializer():
"""Returns an Op that initializes all local variables.
@@ -1507,6 +1577,7 @@ def local_variables_initializer():
return variables_initializer(local_variables())
+@tf_export("initialize_local_variables")
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
@@ -1514,6 +1585,7 @@ def initialize_local_variables():
return local_variables_initializer()
+@tf_export("is_variable_initialized")
@tf_should_use.should_use_result
def is_variable_initialized(variable):
"""Tests if a variable has been initialized.
@@ -1528,6 +1600,7 @@ def is_variable_initialized(variable):
return state_ops.is_variable_initialized(variable)
+@tf_export("assert_variables_initialized")
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
"""Returns an Op to check if variables are initialized.
@@ -1570,6 +1643,7 @@ def assert_variables_initialized(var_list=None):
return array_ops.stack(ranks)
+@tf_export("report_uninitialized_variables")
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
name="report_uninitialized_variables"):
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index 837bca1dbd..12dae94a64 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -33,6 +33,7 @@ from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.tf_export import tf_export
# When a subclass of the Benchmark class is created, it is added to
@@ -181,6 +182,7 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
throughput=throughput, extras=extras)
+@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
"""Abstract class that provides helpers for TensorFlow benchmarks."""
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
index 202475efdf..315889e9aa 100644
--- a/tensorflow/python/platform/gfile.py
+++ b/tensorflow/python/platform/gfile.py
@@ -34,8 +34,10 @@ from tensorflow.python.lib.io.file_io import stat as Stat
from tensorflow.python.lib.io.file_io import walk as Walk
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.tf_export import tf_export
+@tf_export('gfile.GFile', 'gfile.Open')
class GFile(_FileIO):
"""File I/O wrappers without thread locking."""
@@ -43,6 +45,7 @@ class GFile(_FileIO):
super(GFile, self).__init__(name=name, mode=mode)
+@tf_export('gfile.FastGFile')
class FastGFile(_FileIO):
"""File I/O wrappers without thread locking."""
diff --git a/tensorflow/python/platform/sysconfig.py b/tensorflow/python/platform/sysconfig.py
index f6c4f2227f..5c50fa023d 100644
--- a/tensorflow/python/platform/sysconfig.py
+++ b/tensorflow/python/platform/sysconfig.py
@@ -29,9 +29,11 @@ import os.path as _os_path
from tensorflow.python.framework.versions import CXX11_ABI_FLAG as _CXX11_ABI_FLAG
from tensorflow.python.framework.versions import MONOLITHIC_BUILD as _MONOLITHIC_BUILD
from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.tf_export import tf_export
# pylint: disable=g-import-not-at-top
+@tf_export('sysconfig.get_include')
def get_include():
"""Get the directory containing the TensorFlow C++ header files.
@@ -46,6 +48,7 @@ def get_include():
return _os_path.join(_os_path.dirname(tf.__file__), 'include')
+@tf_export('sysconfig.get_lib')
def get_lib():
"""Get the directory containing the TensorFlow framework library.
@@ -56,6 +59,7 @@ def get_lib():
return _os_path.join(_os_path.dirname(tf.__file__))
+@tf_export('sysconfig.get_compile_flags')
def get_compile_flags():
"""Get the compilation flags for custom operators.
@@ -69,6 +73,7 @@ def get_compile_flags():
return flags
+@tf_export('sysconfig.get_link_flags')
def get_link_flags():
"""Get the link flags for custom operators.
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index ec280c6e1e..9b7655722a 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -56,6 +56,7 @@ from tensorflow.python.ops.gradient_checker import compute_gradient
# pylint: enable=unused-import,g-bad-import-order
import sys
+from tensorflow.python.util.tf_export import tf_export
if sys.version_info.major == 2:
import mock # pylint: disable=g-import-not-at-top,unused-import
else:
@@ -68,12 +69,14 @@ Benchmark = _googletest.Benchmark # pylint: disable=invalid-name
StubOutForTesting = _googletest.StubOutForTesting # pylint: disable=invalid-name
+@tf_export('test.main')
def main(argv=None):
"""Runs all unit tests."""
_test_util.InstallStackTraceHandler()
return _googletest.main(argv)
+@tf_export('test.get_temp_dir')
def get_temp_dir():
"""Returns a temporary directory for use during tests.
@@ -85,6 +88,7 @@ def get_temp_dir():
return _googletest.GetTempDir()
+@tf_export('test.test_src_dir_path')
def test_src_dir_path(relative_path):
"""Creates an absolute test srcdir path given a relative path.
@@ -98,6 +102,7 @@ def test_src_dir_path(relative_path):
return _googletest.test_src_dir_path(relative_path)
+@tf_export('test.is_built_with_cuda')
def is_built_with_cuda():
"""Returns whether TensorFlow was built with CUDA (GPU) support."""
return _test_util.IsGoogleCudaEnabled()
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index 9153855588..04ba28c219 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -224,15 +224,15 @@ class PrintModelAnalysisTest(test.TestCase):
# pylint: disable=line-too-long
with gfile.Open(outfile, 'r') as f:
lines = f.read().split('\n')
+ self.assertGreater(len(lines), 5)
result = '\n'.join([l[:min(len(l), 80)] for l in lines])
- self.assertEqual(
- compat.as_bytes(
- 'node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/168.86k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/45.37k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/8 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/1.30k flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/2.30k flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/67.39k f\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/46.66\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/20.74\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/18.58k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/37.00k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/258 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/129 flop\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/141 flops)\n'
- ), compat.as_bytes(lib.CheckAndRemoveDoc(result)))
+ self.assertTrue(
+ compat.as_text(lib.CheckAndRemoveDoc(result))
+ .startswith('node name | # parameters | # float_ops'))
self.assertLess(0, tfprof_node.total_exec_micros)
self.assertEqual(2844, tfprof_node.total_parameters)
- self.assertEqual(168863, tfprof_node.total_float_ops)
+ self.assertLess(168800, tfprof_node.total_float_ops)
self.assertEqual(8, len(tfprof_node.children))
self.assertEqual('_TFProfRoot', tfprof_node.name)
self.assertEqual(
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 083931aa83..3f25311a83 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -24,10 +24,13 @@ limitations under the License.
%rename("%s") TFE_ContextDisableRunMetadata;
%rename("%s") TFE_ContextExportRunMetadata;
%rename("%s") TFE_ContextClearCaches;
+%rename("%s") TFE_ContextGetDevicePlacementPolicy;
+%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_Execute;
+%rename("%s") TFE_Py_FastPathExecute;
%rename("%s") TFE_Py_UID;
%rename("%s") TFE_Py_TapeSetNew;
%rename("%s") TFE_Py_TapeSetRemove;
@@ -118,6 +121,7 @@ limitations under the License.
%rename("%s") TFE_DEVICE_PLACEMENT_EXPLICIT;
%rename("%s") TFE_DEVICE_PLACEMENT_WARN;
%rename("%s") TFE_DEVICE_PLACEMENT_SILENT;
+%rename("%s") TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32;
%include "tensorflow/c/eager/c_api.h"
@@ -155,7 +159,7 @@ limitations under the License.
}
$1 = &temp;
$1->resize(PyInt_AsLong($input), nullptr);
-}
+}
// Create new Status object.
%typemap(in, numinputs=0) TF_Status *out_status {
@@ -180,10 +184,14 @@ limitations under the License.
}
}
+// SWIG usually unwraps the tuple that the native Python/C interface generates.
+// Since we wanted to have a function with a variable length of arguments, we
+// used the native Python/C interface directly (which by default supports
+// passing all arguments as a tuple).
+%native(TFE_Py_FastPathExecute) TFE_Py_FastPathExecute_C;
%include "tensorflow/python/eager/pywrap_tfe.h"
-
// Clear all typemaps.
%typemap(out) TF_DataType;
%typemap(out) int64_t;
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index 355593eca5..92c1fcadd2 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -286,12 +286,13 @@ def merge(inputs, collections=None, name=None):
return val
-def merge_all(key=_ops.GraphKeys.SUMMARIES):
+def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None):
"""Merges all summaries collected in the default graph.
Args:
key: `GraphKey` used to collect the summaries. Defaults to
`GraphKeys.SUMMARIES`.
+ scope: Optional scope used to filter the summary ops, using `re.match`
Returns:
If no summaries were collected, returns None. Otherwise returns a scalar
@@ -310,7 +311,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES):
raise RuntimeError(
'Merging tf.summary.* ops is not compatible with eager execution. '
'Use tf.contrib.summary instead.')
- summary_ops = _ops.get_collection(key)
+ summary_ops = _ops.get_collection(key, scope=scope)
if not summary_ops:
return None
else:
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 0ddf09260b..a2e86a1c43 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -72,7 +72,8 @@ def freeze_graph_with_def_protos(input_graph_def,
variable_names_blacklist="",
input_meta_graph_def=None,
input_saved_model_dir=None,
- saved_model_tags=None):
+ saved_model_tags=None,
+ checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
@@ -100,7 +101,8 @@ def freeze_graph_with_def_protos(input_graph_def,
_ = importer.import_graph_def(input_graph_def, name="")
with session.Session() as sess:
if input_saver_def:
- saver = saver_lib.Saver(saver_def=input_saver_def)
+ saver = saver_lib.Saver(saver_def=input_saver_def,
+ write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
elif input_meta_graph_def:
restorer = saver_lib.import_meta_graph(
@@ -124,7 +126,8 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'global_step' or a similar housekeeping element) so skip it.
continue
var_list[key] = tensor
- saver = saver_lib.Saver(var_list=var_list)
+ saver = saver_lib.Saver(var_list=var_list,
+ write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.split(","))
@@ -217,7 +220,8 @@ def freeze_graph(input_graph,
variable_names_blacklist="",
input_meta_graph=None,
input_saved_model_dir=None,
- saved_model_tags=tag_constants.SERVING):
+ saved_model_tags=tag_constants.SERVING,
+ checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
input_graph_def = None
if input_saved_model_dir:
@@ -236,7 +240,8 @@ def freeze_graph(input_graph,
input_graph_def, input_saver_def, input_checkpoint, output_node_names,
restore_op_name, filename_tensor_name, output_graph, clear_devices,
initializer_nodes, variable_names_whitelist, variable_names_blacklist,
- input_meta_graph_def, input_saved_model_dir, saved_model_tags.split(","))
+ input_meta_graph_def, input_saved_model_dir,
+ saved_model_tags.split(","), checkpoint_version=checkpoint_version)
def main(unused_args):
@@ -246,7 +251,7 @@ def main(unused_args):
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist,
FLAGS.input_meta_graph, FLAGS.input_saved_model_dir,
- FLAGS.saved_model_tags)
+ FLAGS.saved_model_tags, checkpoint_version=checkpoint_version)
if __name__ == "__main__":
@@ -268,6 +273,11 @@ if __name__ == "__main__":
default="",
help="TensorFlow variables file to load.")
parser.add_argument(
+ "--checkpoint_version",
+ type=int,
+ default=saver_pb2.SaverDef.V2,
+ help="Tensorflow variable file format")
+ parser.add_argument(
"--output_graph",
type=str,
default="",
diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py
index feeed7102c..342732465d 100644
--- a/tensorflow/python/tools/freeze_graph_test.py
+++ b/tensorflow/python/tools/freeze_graph_test.py
@@ -86,7 +86,8 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
freeze_graph.freeze_graph(
input_graph_path, input_saver_def_path, input_binary, checkpoint_path,
output_node_names, restore_op_name, filename_tensor_name,
- output_graph_path, clear_devices, "", "", input_meta_graph)
+ output_graph_path, clear_devices, "", "", input_meta_graph,
+ checkpoint_version=saver_write_version)
# Now we make sure the variable is now a constant, and that the graph still
# produces the expected result.
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index 8716058e61..dd876cbe7f 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -97,8 +97,9 @@ def parse_numpy_printoption(kv_str):
raise argparse.ArgumentTypeError(
"Setting '%s' from the command line is not supported." % k)
try:
- v = (v_type(v_str) if v_type is not bool
- else flags.BooleanParser().parse(v_str))
+ v = (
+ v_type(v_str)
+ if v_type is not bool else flags.BooleanParser().parse(v_str))
except ValueError as e:
raise argparse.ArgumentTypeError(e.message)
np.set_printoptions(**{k: v})
@@ -121,9 +122,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
- "--file_name", type=str, default="", help="Checkpoint filename. "
- "Note, if using Checkpoint V2 format, file_name is the "
- "shared prefix between all files in the checkpoint.")
+ "--file_name",
+ type=str,
+ default="",
+ help="Checkpoint filename. "
+ "Note, if using Checkpoint V2 format, file_name is the "
+ "shared prefix between all files in the checkpoint.")
parser.add_argument(
"--tensor_name",
type=str,
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index ce64fdf709..21e8e803fc 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -33,6 +33,7 @@ import numpy as np
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
+from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
@@ -377,7 +378,7 @@ def preprocess_input_exprs_arg_string(input_exprs_str):
'input_key=<python expression>'
Returns:
- A dictionary that maps input keys to python expressions.
+ A dictionary that maps input keys to their values.
Raises:
RuntimeError: An error when the given input string is in a bad format.
@@ -388,17 +389,75 @@ def preprocess_input_exprs_arg_string(input_exprs_str):
if '=' not in input_exprs_str:
raise RuntimeError('--input_exprs "%s" format is incorrect. Please follow'
'"<input_key>=<python expression>"' % input_exprs_str)
- input_key, expr = input_raw.split('=')
- input_dict[input_key] = expr
+ input_key, expr = input_raw.split('=', 1)
+ # ast.literal_eval does not work with numpy expressions
+ input_dict[input_key] = eval(expr) # pylint: disable=eval-used
+ return input_dict
+
+def preprocess_input_examples_arg_string(input_examples_str):
+ """Parses input into dict that maps input keys to lists of tf.Example.
+
+ Parses input string in the format of 'input_key1=[{feature_name:
+ feature_list}];input_key2=[{feature_name:feature_list}];' into a dictionary
+ that maps each input_key to its list of serialized tf.Example.
+
+ Args:
+ input_examples_str: A string that specifies a list of dictionaries of
+ feature_names and their feature_lists for each input.
+ Each input is separated by semicolon. For each input key:
+ 'input=[{feature_name1: feature_list1, feature_name2:feature_list2}]'
+ items in feature_list can be the type of float, int, long or str.
+
+ Returns:
+ A dictionary that maps input keys to lists of serialized tf.Example.
+
+ Raises:
+ ValueError: An error when the given tf.Example is not a list.
+ """
+ input_dict = preprocess_input_exprs_arg_string(input_examples_str)
+ for input_key, example_list in input_dict.items():
+ if not isinstance(example_list, list):
+ raise ValueError(
+ 'tf.Example input must be a list of dictionaries, but "%s" is %s' %
+ (example_list, type(example_list)))
+ input_dict[input_key] = [
+ _create_example_string(example) for example in example_list
+ ]
return input_dict
-def load_inputs_from_input_arg_string(inputs_str, input_exprs_str):
+def _create_example_string(example_dict):
+ """Create a serialized tf.example from feature dictionary."""
+ example = example_pb2.Example()
+ for feature_name, feature_list in example_dict.items():
+ if not isinstance(feature_list, list):
+ raise ValueError('feature value must be a list, but %s: "%s" is %s' %
+ (feature_name, feature_list, type(feature_list)))
+ if isinstance(feature_list[0], float):
+ example.features.feature[feature_name].float_list.value.extend(
+ feature_list)
+ elif isinstance(feature_list[0], str):
+ example.features.feature[feature_name].bytes_list.value.extend(
+ feature_list)
+ elif isinstance(feature_list[0], (int, long)):
+ example.features.feature[feature_name].int64_list.value.extend(
+ feature_list)
+ else:
+ raise ValueError(
+ 'Type %s for value %s is not supported for tf.train.Feature.' %
+ (type(feature_list[0]), feature_list[0]))
+ return example.SerializeToString()
+
+
+def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
+ input_examples_str):
"""Parses input arg strings and create inputs feed_dict.
Parses '--inputs' string for inputs to be loaded from file, and parses
'--input_exprs' string for inputs to be evaluated from python expression.
+ '--input_examples' string for inputs to be created from tf.example feature
+ dictionary list.
Args:
inputs_str: A string that specified where to load inputs. Each input is
@@ -424,9 +483,11 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str):
to the specified input tensor, else SavedModel CLI will assume a
dictionary is stored in the pickle file and the value corresponding to
the variable_name will be used.
- input_exprs_str: A string that specified python expressions for inputs.
+ input_exprs_str: A string that specifies python expressions for inputs.
* In the format of: '<input_key>=<python expression>'.
* numpy module is available as np.
+ input_examples_str: A string that specifies tf.Example with dictionary.
+ * In the format of: '<input_key>=<[{feature:value list}]>'
Returns:
A dictionary that maps input tensor keys to numpy ndarrays.
@@ -441,6 +502,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str):
inputs = preprocess_inputs_arg_string(inputs_str)
input_exprs = preprocess_input_exprs_arg_string(input_exprs_str)
+ input_examples = preprocess_input_examples_arg_string(input_examples_str)
for input_tensor_key, (filename, variable_name) in inputs.items():
data = np.load(filename)
@@ -474,15 +536,20 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str):
tensor_key_feed_dict[input_tensor_key] = data
# When input is a python expression:
- for input_tensor_key, py_expr in input_exprs.items():
+ for input_tensor_key, py_expr_evaluated in input_exprs.items():
if input_tensor_key in tensor_key_feed_dict:
warnings.warn(
'input_key %s has been specified with both --inputs and --input_exprs'
' options. Value in --input_exprs will be used.' % input_tensor_key)
+ tensor_key_feed_dict[input_tensor_key] = py_expr_evaluated
- # ast.literal_eval does not work with numpy expressions
- tensor_key_feed_dict[input_tensor_key] = eval(py_expr) # pylint: disable=eval-used
-
+ # When input is a tf.Example:
+ for input_tensor_key, example in input_examples.items():
+ if input_tensor_key in tensor_key_feed_dict:
+ warnings.warn(
+ 'input_key %s has been specified in multiple options. Value in '
+ '--input_examples will be used.' % input_tensor_key)
+ tensor_key_feed_dict[input_tensor_key] = example
return tensor_key_feed_dict
@@ -518,11 +585,12 @@ def run(args):
AttributeError: An error when neither --inputs nor --input_exprs is passed
to run command.
"""
- if not args.inputs and not args.input_exprs:
+ if not args.inputs and not args.input_exprs and not args.input_examples:
raise AttributeError(
- 'At least one of --inputs and --input_exprs must be required')
+ 'At least one of --inputs, --input_exprs or --input_examples must be '
+ 'required')
tensor_key_feed_dict = load_inputs_from_input_arg_string(
- args.inputs, args.input_exprs)
+ args.inputs, args.input_exprs, args.input_examples)
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
args.overwrite, tf_debug=args.tf_debug)
@@ -589,10 +657,12 @@ def create_parser():
run_msg = ('Usage example:\n'
'To run input tensors from files through a MetaGraphDef and save'
' the output tensors to files:\n'
- '$saved_model_cli show --dir /tmp/saved_model --tag_set serve'
+ '$saved_model_cli show --dir /tmp/saved_model --tag_set serve '
'--signature_def serving_default '
- '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy'
- '--input_exprs \'input3_key=np.ones(2)\' --outdir=/out\n\n'
+ '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy '
+ '--input_exprs \'input3_key=np.ones(2)\' --input_examples '
+ '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' '
+ '--outdir=/out\n\n'
'For more information about input file format, please see:\n'
'https://www.tensorflow.org/programmers_guide/saved_model_cli\n')
parser_run = subparsers.add_parser(
@@ -620,8 +690,14 @@ def create_parser():
msg = ('Specifying inputs by python expressions, in the format of'
' "<input_key>=\'<python expression>\'", separated by \';\'. '
'numpy module is available as \'np\'. '
- 'Will override duplicate input_keys from --inputs option.')
+ 'Will override duplicate input keys from --inputs option.')
parser_run.add_argument('--input_exprs', type=str, default='', help=msg)
+ msg = (
+ 'Specifying tf.Example inputs as list of dictionaries. For example: '
+ '<input_key>=[{feature0:value_list,feature1:value_list}]. Use ";" to '
+ 'separate input keys. Will override duplicate input keys from --inputs '
+ 'and --input_exprs option.')
+ parser_run.add_argument('--input_examples', type=str, default='', help=msg)
parser_run.add_argument(
'--outdir',
type=str,
diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py
index 0789e1e107..d6cbc49ba1 100644
--- a/tensorflow/python/tools/saved_model_cli_test.py
+++ b/tensorflow/python/tools/saved_model_cli_test.py
@@ -218,8 +218,9 @@ Method name is: tensorflow/serving/predict"""
input_expr_str)
self.assertTrue(input_dict['input1'] == ('/path/file.txt', 'ab3'))
self.assertTrue(input_dict['input2'] == ('file2', None))
- self.assertTrue(input_expr_dict['input3'] == 'np.zeros([2,2])')
- self.assertTrue(input_expr_dict['input4'] == '[4,5]')
+ print(input_expr_dict['input3'])
+ self.assertAllClose(input_expr_dict['input3'], np.zeros([2, 2]))
+ self.assertAllClose(input_expr_dict['input4'], [4, 5])
self.assertTrue(len(input_dict) == 2)
self.assertTrue(len(input_expr_dict) == 2)
@@ -250,7 +251,8 @@ Method name is: tensorflow/serving/predict"""
np.save(input0_path, x0)
np.save(input1_path, x1)
input_str = 'x0=' + input0_path + '[x0];x1=' + input1_path
- feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
+ feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
+ input_str, '', '')
self.assertTrue(np.all(feed_dict['x0'] == x0))
self.assertTrue(np.all(feed_dict['x1'] == x1))
@@ -259,7 +261,8 @@ Method name is: tensorflow/serving/predict"""
input_path = os.path.join(test.get_temp_dir(), 'input.npz')
np.savez(input_path, a=x0)
input_str = 'x=' + input_path + '[a];y=' + input_path
- feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
+ feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
+ input_str, '', '')
self.assertTrue(np.all(feed_dict['x'] == x0))
self.assertTrue(np.all(feed_dict['y'] == x0))
@@ -278,7 +281,8 @@ Method name is: tensorflow/serving/predict"""
pickle.dump(pkl2, f)
input_str = 'x=' + input_path0 + '[b];y=' + input_path1 + '[c];'
input_str += 'z=' + input_path2
- feed_dict = saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
+ feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
+ input_str, '', '')
self.assertTrue(np.all(feed_dict['x'] == pkl0['b']))
self.assertTrue(np.all(feed_dict['y'] == pkl1))
self.assertTrue(np.all(feed_dict['z'] == pkl2))
@@ -291,7 +295,7 @@ Method name is: tensorflow/serving/predict"""
input_expr_str = ('x1=np.ones([2,10]);x2=np.array([[1],[2],[3]]);'
'x3=np.mgrid[0:5,0:5];x4=[[3],[4]]')
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
- '', input_expr_str)
+ '', input_expr_str, '')
self.assertTrue(np.all(feed_dict['x1'] == x1))
self.assertTrue(np.all(feed_dict['x2'] == x2))
self.assertTrue(np.all(feed_dict['x3'] == x3))
@@ -305,7 +309,7 @@ Method name is: tensorflow/serving/predict"""
input_str = 'x0=' + input_path + '[a]'
input_expr_str = 'x1=np.ones([2,10])'
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
- input_str, input_expr_str)
+ input_str, input_expr_str, '')
self.assertTrue(np.all(feed_dict['x0'] == x0))
self.assertTrue(np.all(feed_dict['x1'] == x1))
@@ -317,7 +321,7 @@ Method name is: tensorflow/serving/predict"""
input_str = 'x0=' + input_path + '[a]'
input_expr_str = 'x0=np.ones([2,10])'
feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
- input_str, input_expr_str)
+ input_str, input_expr_str, '')
self.assertTrue(np.all(feed_dict['x0'] == x1))
def testInputParserErrorNoName(self):
@@ -327,7 +331,7 @@ Method name is: tensorflow/serving/predict"""
np.savez(input_path, a=x0, b=x1)
input_str = 'x=' + input_path
with self.assertRaises(RuntimeError):
- saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
+ saved_model_cli.load_inputs_from_input_arg_string(input_str, '', '')
def testInputParserErrorWrongName(self):
x0 = np.array([[1], [2]])
@@ -336,7 +340,22 @@ Method name is: tensorflow/serving/predict"""
np.savez(input_path, a=x0, b=x1)
input_str = 'x=' + input_path + '[c]'
with self.assertRaises(RuntimeError):
- saved_model_cli.load_inputs_from_input_arg_string(input_str, '')
+ saved_model_cli.load_inputs_from_input_arg_string(input_str, '', '')
+
+ def testRunCommandInputExamples(self):
+ self.parser = saved_model_cli.create_parser()
+ base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+ output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
+ args = self.parser.parse_args([
+ 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
+ 'regress_x_to_y', '--input_examples',
+ 'inputs=[{"x":[8.0],"x2":[5.0]}, {"x":[4.0],"x2":[3.0]}]', '--outdir',
+ output_dir
+ ])
+ saved_model_cli.run(args)
+ y_actual = np.load(os.path.join(output_dir, 'outputs.npy'))
+ y_expected = np.array([[6.0], [4.0]])
+ self.assertAllEqual(y_expected, y_actual)
def testRunCommandExistingOutdir(self):
self.parser = saved_model_cli.create_parser()
@@ -410,6 +429,42 @@ Method name is: tensorflow/serving/predict"""
with self.assertRaises(ValueError):
saved_model_cli.run(args)
+ def testRunCommandInputExamplesNotListError(self):
+ self.parser = saved_model_cli.create_parser()
+ base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+ output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
+ args = self.parser.parse_args([
+ 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
+ 'regress_x_to_y', '--input_examples', 'inputs={"x":8.0,"x2":5.0}',
+ '--outdir', output_dir
+ ])
+ with self.assertRaisesRegexp(ValueError, 'must be a list'):
+ saved_model_cli.run(args)
+
+ def testRunCommandInputExamplesFeatureValueNotListError(self):
+ self.parser = saved_model_cli.create_parser()
+ base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+ output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
+ args = self.parser.parse_args([
+ 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
+ 'regress_x_to_y', '--input_examples', 'inputs=[{"x":8.0,"x2":5.0}]',
+ '--outdir', output_dir
+ ])
+ with self.assertRaisesRegexp(ValueError, 'feature value must be a list'):
+ saved_model_cli.run(args)
+
+ def testRunCommandInputExamplesFeatureBadType(self):
+ self.parser = saved_model_cli.create_parser()
+ base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+ output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
+ args = self.parser.parse_args([
+ 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
+ 'regress_x_to_y', '--input_examples', 'inputs=[{"x":[[1],[2]]}]',
+ '--outdir', output_dir
+ ])
+ with self.assertRaisesRegexp(ValueError, 'is not supported'):
+ saved_model_cli.run(args)
+
def testRunCommandOutputFileExistError(self):
self.parser = saved_model_cli.create_parser()
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 266f5563e0..0c69f8bf39 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -24,7 +24,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
@@ -101,19 +100,16 @@ class AdamOptimizer(optimizer.Optimizer):
self._beta2_t = None
self._epsilon_t = None
- # Variables to accumulate the powers of the beta parameters.
- # Created in _create_slots when we know the variables to optimize.
- self._beta1_power = None
- self._beta2_power = None
-
# Created in SparseApply if needed.
self._updated_lr = None
def _get_beta_accumulators(self):
- return self._beta1_power, self._beta2_power
-
- def _non_slot_variables(self):
- return self._get_beta_accumulators()
+ if context.in_graph_mode():
+ graph = ops.get_default_graph()
+ else:
+ graph = None
+ return (self._get_non_slot_variable("beta1_power", graph=graph),
+ self._get_non_slot_variable("beta2_power", graph=graph))
def _create_slots(self, var_list):
# Create the beta1 and beta2 accumulators on the same device as the first
@@ -121,19 +117,13 @@ class AdamOptimizer(optimizer.Optimizer):
# workers (these need to go on the same PS, otherwise some updates are
# silently ignored).
first_var = min(var_list, key=lambda x: x.name)
+ self._create_non_slot_variable(initial_value=self._beta1,
+ name="beta1_power",
+ colocate_with=first_var)
+ self._create_non_slot_variable(initial_value=self._beta2,
+ name="beta2_power",
+ colocate_with=first_var)
- create_new = self._beta1_power is None
- if not create_new and context.in_graph_mode():
- create_new = (self._beta1_power.graph is not first_var.graph)
-
- if create_new:
- with ops.colocate_with(first_var):
- self._beta1_power = variable_scope.variable(self._beta1,
- name="beta1_power",
- trainable=False)
- self._beta2_power = variable_scope.variable(self._beta2,
- name="beta2_power",
- trainable=False)
# Create slots for the first and second moments.
for v in var_list:
self._zeros_slot(v, "m", self._name)
@@ -148,10 +138,11 @@ class AdamOptimizer(optimizer.Optimizer):
def _apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
+ beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.apply_adam(
var, m, v,
- math_ops.cast(self._beta1_power, var.dtype.base_dtype),
- math_ops.cast(self._beta2_power, var.dtype.base_dtype),
+ math_ops.cast(beta1_power, var.dtype.base_dtype),
+ math_ops.cast(beta2_power, var.dtype.base_dtype),
math_ops.cast(self._lr_t, var.dtype.base_dtype),
math_ops.cast(self._beta1_t, var.dtype.base_dtype),
math_ops.cast(self._beta2_t, var.dtype.base_dtype),
@@ -161,10 +152,11 @@ class AdamOptimizer(optimizer.Optimizer):
def _resource_apply_dense(self, grad, var):
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
+ beta1_power, beta2_power = self._get_beta_accumulators()
return training_ops.resource_apply_adam(
var.handle, m.handle, v.handle,
- math_ops.cast(self._beta1_power, grad.dtype.base_dtype),
- math_ops.cast(self._beta2_power, grad.dtype.base_dtype),
+ math_ops.cast(beta1_power, grad.dtype.base_dtype),
+ math_ops.cast(beta2_power, grad.dtype.base_dtype),
math_ops.cast(self._lr_t, grad.dtype.base_dtype),
math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
@@ -172,8 +164,9 @@ class AdamOptimizer(optimizer.Optimizer):
grad, use_locking=self._use_locking)
def _apply_sparse_shared(self, grad, var, indices, scatter_add):
- beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
- beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
@@ -217,12 +210,11 @@ class AdamOptimizer(optimizer.Optimizer):
def _finish(self, update_ops, name_scope):
# Update the power accumulators.
with ops.control_dependencies(update_ops):
- with ops.colocate_with(self._beta1_power):
- update_beta1 = self._beta1_power.assign(
- self._beta1_power * self._beta1_t,
- use_locking=self._use_locking)
- update_beta2 = self._beta2_power.assign(
- self._beta2_power * self._beta2_t,
- use_locking=self._use_locking)
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ with ops.colocate_with(beta1_power):
+ update_beta1 = beta1_power.assign(
+ beta1_power * self._beta1_t, use_locking=self._use_locking)
+ update_beta2 = beta2_power.assign(
+ beta2_power * self._beta2_t, use_locking=self._use_locking)
return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
name=name_scope)
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index ffb66abc4c..a521f1299e 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -174,8 +174,11 @@ class AdamOptimizerTest(test.TestCase):
opt = adam.AdamOptimizer()
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
opt_variables = opt.variables()
- self.assertIn(opt._beta1_power, opt_variables)
- self.assertIn(opt._beta2_power, opt_variables)
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertTrue(beta1_power is not None)
+ self.assertTrue(beta2_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
with ops.Graph().as_default():
# Shouldn't return non-slot variables from other graphs.
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 5054873bc1..b5d3e78797 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -176,7 +176,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
reader = load_checkpoint(ckpt_dir_or_file)
variable_map = reader.get_variable_to_shape_map()
- for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map):
+ for tensor_name_in_ckpt, current_var_or_name in sorted(
+ six.iteritems(assignment_map)):
var = None
# Check if this is Variable object or list of Variable objects (in case of
# partitioned variables).
@@ -233,7 +234,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
if "/part_" in var_name:
var_name = var_name[:var_name.index("/part_")]
scope_variables.add(var_name)
- for var_name in scope_variables:
+ for var_name in sorted(scope_variables):
# Lookup name with specified prefix and suffix from current variable.
# If tensor_name given is '/' (root), don't use it for full name.
full_tensor_name = var_name[len(scopes):]
diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py
index 149d3eed41..3e4ac1dfff 100644
--- a/tensorflow/python/training/coordinator_test.py
+++ b/tensorflow/python/training/coordinator_test.py
@@ -85,8 +85,8 @@ class CoordinatorTest(test.TestCase):
self.assertFalse(coord.wait_for_stop(0.1))
wait_for_stop_ev = threading.Event()
has_stopped_ev = threading.Event()
- t = threading.Thread(target=StopOnEvent,
- args=(coord, wait_for_stop_ev, has_stopped_ev))
+ t = threading.Thread(
+ target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev))
t.start()
self.assertFalse(coord.should_stop())
self.assertFalse(coord.wait_for_stop(0.01))
@@ -100,7 +100,8 @@ class CoordinatorTest(test.TestCase):
threads = [
threading.Thread(target=SleepABit, args=(0.01,)),
threading.Thread(target=SleepABit, args=(0.02,)),
- threading.Thread(target=SleepABit, args=(0.01,))]
+ threading.Thread(target=SleepABit, args=(0.01,))
+ ]
for t in threads:
t.start()
coord.join(threads)
@@ -112,7 +113,8 @@ class CoordinatorTest(test.TestCase):
threads = [
threading.Thread(target=SleepABit, args=(0.01, coord)),
threading.Thread(target=SleepABit, args=(0.02, coord)),
- threading.Thread(target=SleepABit, args=(0.01, coord))]
+ threading.Thread(target=SleepABit, args=(0.01, coord))
+ ]
for t in threads:
t.start()
WaitForThreadsToRegister(coord, 3)
@@ -125,7 +127,8 @@ class CoordinatorTest(test.TestCase):
threads = [
threading.Thread(target=SleepABit, args=(0.01, coord)),
threading.Thread(target=SleepABit, args=(0.02,)),
- threading.Thread(target=SleepABit, args=(0.01, coord))]
+ threading.Thread(target=SleepABit, args=(0.01, coord))
+ ]
for t in threads:
t.start()
WaitForThreadsToRegister(coord, 2)
@@ -135,14 +138,17 @@ class CoordinatorTest(test.TestCase):
self.assertFalse(t.is_alive())
def testJoinGraceExpires(self):
+
def TestWithGracePeriod(stop_grace_period):
coord = coordinator.Coordinator()
wait_for_stop_ev = threading.Event()
has_stopped_ev = threading.Event()
threads = [
- threading.Thread(target=StopOnEvent,
- args=(coord, wait_for_stop_ev, has_stopped_ev)),
- threading.Thread(target=SleepABit, args=(10.0,))]
+ threading.Thread(
+ target=StopOnEvent,
+ args=(coord, wait_for_stop_ev, has_stopped_ev)),
+ threading.Thread(target=SleepABit, args=(10.0,))
+ ]
for t in threads:
t.daemon = True
t.start()
@@ -150,6 +156,7 @@ class CoordinatorTest(test.TestCase):
has_stopped_ev.wait()
with self.assertRaisesRegexp(RuntimeError, "threads still running"):
coord.join(threads, stop_grace_period_secs=stop_grace_period)
+
TestWithGracePeriod(1e-10)
TestWithGracePeriod(0.002)
TestWithGracePeriod(1.0)
@@ -159,16 +166,16 @@ class CoordinatorTest(test.TestCase):
wait_for_stop_ev = threading.Event()
has_stopped_ev = threading.Event()
threads = [
- threading.Thread(target=StopOnEvent,
- args=(coord, wait_for_stop_ev, has_stopped_ev)),
- threading.Thread(target=SleepABit, args=(10.0,))]
+ threading.Thread(
+ target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev)),
+ threading.Thread(target=SleepABit, args=(10.0,))
+ ]
for t in threads:
t.daemon = True
t.start()
wait_for_stop_ev.set()
has_stopped_ev.wait()
- coord.join(
- threads, stop_grace_period_secs=1., ignore_live_threads=True)
+ coord.join(threads, stop_grace_period_secs=1., ignore_live_threads=True)
def testJoinRaiseReportExcInfo(self):
coord = coordinator.Coordinator()
@@ -180,7 +187,8 @@ class CoordinatorTest(test.TestCase):
args=(coord, ev_1, ev_2, RuntimeError("First"), False)),
threading.Thread(
target=RaiseOnEvent,
- args=(coord, ev_2, None, RuntimeError("Too late"), False))]
+ args=(coord, ev_2, None, RuntimeError("Too late"), False))
+ ]
for t in threads:
t.start()
@@ -199,7 +207,8 @@ class CoordinatorTest(test.TestCase):
args=(coord, ev_1, ev_2, RuntimeError("First"), True)),
threading.Thread(
target=RaiseOnEvent,
- args=(coord, ev_2, None, RuntimeError("Too late"), True))]
+ args=(coord, ev_2, None, RuntimeError("Too late"), True))
+ ]
for t in threads:
t.start()
@@ -214,9 +223,8 @@ class CoordinatorTest(test.TestCase):
threading.Thread(
target=RaiseOnEvent,
args=(coord, ev_1, None,
- errors_impl.OutOfRangeError(None, None, "First"),
- True))
- ]
+ errors_impl.OutOfRangeError(None, None, "First"), True))
+ ]
for t in threads:
t.start()
@@ -230,7 +238,7 @@ class CoordinatorTest(test.TestCase):
threading.Thread(
target=RaiseOnEvent,
args=(coord, ev_1, None, ValueError("Clean stop"), True))
- ]
+ ]
for t in threads:
t.start()
@@ -247,7 +255,8 @@ class CoordinatorTest(test.TestCase):
args=(coord, ev_1, ev_2, RuntimeError("First"))),
threading.Thread(
target=RaiseOnEventUsingContextHandler,
- args=(coord, ev_2, None, RuntimeError("Too late")))]
+ args=(coord, ev_2, None, RuntimeError("Too late")))
+ ]
for t in threads:
t.start()
@@ -262,7 +271,7 @@ class CoordinatorTest(test.TestCase):
threading.Thread(
target=RaiseOnEvent,
args=(coord, ev_1, None, RuntimeError("First"), True)),
- ]
+ ]
for t in threads:
t.start()
@@ -274,7 +283,7 @@ class CoordinatorTest(test.TestCase):
threading.Thread(
target=RaiseOnEvent,
args=(coord, ev_1, None, RuntimeError("Second"), True)),
- ]
+ ]
for t in threads:
t.start()
with self.assertRaisesRegexp(RuntimeError, "Second"):
@@ -337,24 +346,29 @@ class LooperTest(test.TestCase):
def testTargetArgs(self):
n = [3]
coord = coordinator.Coordinator()
- thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0,
- args=(coord, n))
+ thread = coordinator.LooperThread.loop(
+ coord, 0, target=_StopAt0, args=(coord, n))
coord.join([thread])
self.assertEqual(0, n[0])
def testTargetKwargs(self):
n = [3]
coord = coordinator.Coordinator()
- thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0,
- kwargs={"coord": coord, "n": n})
+ thread = coordinator.LooperThread.loop(
+ coord, 0, target=_StopAt0, kwargs={
+ "coord": coord,
+ "n": n
+ })
coord.join([thread])
self.assertEqual(0, n[0])
def testTargetMixedArgs(self):
n = [3]
coord = coordinator.Coordinator()
- thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0,
- args=(coord,), kwargs={"n": n})
+ thread = coordinator.LooperThread.loop(
+ coord, 0, target=_StopAt0, args=(coord,), kwargs={
+ "n": n
+ })
coord.join([thread])
self.assertEqual(0, n[0])
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index e34c759e89..43ed1ac170 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -187,7 +187,7 @@ def _zero_debias(unbiased_var, value, decay):
with variable_scope.variable_scope(
unbiased_var.op.name, values=[unbiased_var, value, decay]) as scope:
with ops.colocate_with(unbiased_var):
- with ops.control_dependencies(None):
+ with ops.init_scope():
biased_initializer = init_ops.zeros_initializer(
dtype=unbiased_var.dtype)(unbiased_var.get_shape())
local_step_initializer = init_ops.zeros_initializer()
@@ -385,7 +385,7 @@ class ExponentialMovingAverage(object):
# For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism.
- with ops.control_dependencies(None):
+ with ops.init_scope():
if isinstance(var, variables.Variable):
avg = slot_creator.create_slot(var,
var.initialized_value(),
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 56cf4d42ee..719b83e5ca 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util import nest
@@ -299,6 +300,7 @@ class Optimizer(object):
# Dictionary of slots.
# {slot_name : { variable_to_train: slot_for_the_variable, ...}, ... }
self._slots = {}
+ self._non_slot_dict = {}
def get_name(self):
return self._name
@@ -512,7 +514,7 @@ class Optimizer(object):
if not var_list:
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, _, v in converted_grads_and_vars],))
- with ops.control_dependencies(None):
+ with ops.init_scope():
self._create_slots([_get_variable_for(v) for v in var_list])
update_ops = []
with ops.name_scope(name, self._name) as name:
@@ -603,17 +605,32 @@ class Optimizer(object):
# Sort variables by name so that the return is deterministic.
return sorted(optimizer_variables, key=lambda v: v.name)
+ def _create_non_slot_variable(self, initial_value, name, colocate_with):
+ """Add an extra variable, not associated with a slot."""
+ if context.in_graph_mode():
+ graph = colocate_with.graph
+ else:
+ graph = None
+
+ key = (name, graph)
+ v = self._non_slot_dict.get(key, None)
+ if v is None:
+ with ops.colocate_with(colocate_with):
+ v = variable_scope.variable(initial_value, name=name, trainable=False)
+ self._non_slot_dict[key] = v
+
+ return v
+
+ def _get_non_slot_variable(self, name, graph=None):
+ return self._non_slot_dict.get((name, graph), None)
+
def _non_slot_variables(self):
"""Additional variables created by the `Optimizer`.
- This method should be overridden by child classes which create extra
- variables, so that `variables()` includes the `Optimizer`'s non-slot
- variables.
-
Returns:
A list or tuple of variables.
"""
- return []
+ return self._non_slot_dict.values()
def _assert_valid_dtypes(self, tensors):
"""Asserts tensors are all valid types (see `_valid_dtypes`).
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 2c59b82ebe..4f3773c0fc 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1592,9 +1592,9 @@ class Saver(object):
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
- A string: path prefix used for the checkpoint files. If the saver is
- sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
- is the number of shards created.
+ A string: path prefix used for the checkpoint files. If checkpoint
+ format is V1 and the saver is sharded, this string ends with:
+ '-?????-of-nnnnn' where 'nnnnn' is the number of shards created.
If the saver is empty, returns None.
Raises:
@@ -1744,6 +1744,11 @@ class Saver(object):
return
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
+ if (os.path.isfile(save_path) and
+ self._write_version != saver_pb2.SaverDef.V1):
+ raise ValueError("The specified path: %s is a file."
+ " Please specify only the path prefix"
+ " to the checkpoint files." % save_path)
logging.info("Restoring parameters from %s", save_path)
if context.in_graph_mode():
sess.run(self.saver_def.restore_op_name,
diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py
index 297284f80c..fff17402e2 100644
--- a/tensorflow/python/training/sync_replicas_optimizer_test.py
+++ b/tensorflow/python/training/sync_replicas_optimizer_test.py
@@ -286,8 +286,9 @@ class SyncReplicasOptimizerHookTest(test.TestCase):
global_step = variables.Variable(0, name="global_step", trainable=False)
opt.minimize(v, global_step=global_step)
opt_variables = opt.variables()
- self.assertIn(opt._opt._beta1_power, opt_variables)
- self.assertIn(opt._opt._beta2_power, opt_variables)
+ beta1_power, beta2_power = opt._opt._get_beta_accumulators()
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
if __name__ == "__main__":
diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py
index 07382d93df..270d96a3c7 100644
--- a/tensorflow/python/util/compat.py
+++ b/tensorflow/python/util/compat.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Functions for Python 2 vs. 3 compatibility.
## Conversion routines
@@ -21,6 +20,7 @@ In addition to the functions below, `as_str` converts an object to a `str`.
@@as_bytes
@@as_text
@@as_str_any
+@@path_to_str
## Types
The compatibility module also provides the following types:
@@ -108,17 +108,29 @@ def as_str_any(value):
return str(value)
+def path_to_str(path):
+ """Returns the file system path representation of a `PathLike` object, else as it is.
+
+ Args:
+ path: An object that can be converted to path representation.
+
+ Returns:
+ A `str` object.
+ """
+ if hasattr(path, '__fspath__'):
+ path = as_str_any(path.__fspath__())
+ return path
+
+
# Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we
# need to check them specifically. The same goes from Real and Complex.
integral_types = (_numbers.Integral, _np.integer)
real_types = (_numbers.Real, _np.integer, _np.floating)
complex_types = (_numbers.Complex, _np.number)
-
# Either bytes or text.
bytes_or_text_types = (bytes, _six.text_type)
-
_allowed_symbols = [
'as_str',
'bytes_or_text_types',
diff --git a/tensorflow/python/util/compat_internal.py b/tensorflow/python/util/compat_internal.py
new file mode 100644
index 0000000000..a299b2fc3c
--- /dev/null
+++ b/tensorflow/python/util/compat_internal.py
@@ -0,0 +1,34 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functions for Python 2 vs. 3 compatibility that are private to TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def path_to_str(path):
+ """Returns the file system path representation of a `PathLike` object, else as it is.
+
+ Args:
+ path: An object that can be converted to path representation.
+
+ Returns:
+ A `str` object.
+ """
+ if hasattr(path, "__fspath__"):
+ path = as_str_any(path.__fspath__())
+ return path
diff --git a/tensorflow/python/util/kernel_registry.h b/tensorflow/python/util/kernel_registry.h
index c00b60d91b..1ba76f020b 100644
--- a/tensorflow/python/util/kernel_registry.h
+++ b/tensorflow/python/util/kernel_registry.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Functions for getting information about kernels registered in the binary.
-#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
-#define THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
+#ifndef TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
+#define TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -31,4 +31,4 @@ string TryFindKernelClass(const string& serialized_node_def);
} // namespace swig
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
+#endif // TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 4ce871de72..874df3d108 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -47,10 +47,25 @@ def _sorted(dict_):
raise TypeError("nest only supports dicts with sortable keys.")
-def _is_namedtuple(instance):
- """Returns True iff `instance` is a `namedtuple`."""
+def _is_namedtuple(instance, strict=False):
+ """Returns True iff `instance` is a `namedtuple`.
+
+ Args:
+ instance: An instance of a Python object.
+ strict: If True, `instance` is considered to be a `namedtuple` only if
+ it is a "plain" namedtuple. For instance, a class inheriting
+ from a `namedtuple` will be considered to be a `namedtuple`
+ iff `strict=False`.
+
+ Returns:
+ True if `instance` is a `namedtuple`.
+ """
+ # Attemp to limit the test to plain namedtuple (not stuff inheriting from it).
+ if not isinstance(instance, tuple):
+ return False
+ if strict and instance.__class__.__base__ != tuple:
+ return False
return (
- isinstance(instance, tuple) and
hasattr(instance, "_fields") and
isinstance(instance._fields, _collections.Sequence) and
all(isinstance(f, _six.string_types) for f in instance._fields))
@@ -140,8 +155,37 @@ def flatten(nest):
return _pywrap_tensorflow.Flatten(nest)
+def _same_namedtuples(nest1, nest2):
+ """Returns True if the two namedtuples have the same name and fields."""
+ if nest1._fields != nest2._fields:
+ return False
+ if nest1.__class__.__name__ != nest2.__class__.__name__:
+ return False
+ return True
+
+
def _recursive_assert_same_structure(nest1, nest2, check_types):
- """Helper function for `assert_same_structure`."""
+ """Helper function for `assert_same_structure`.
+
+ See `assert_same_structure` for further information about namedtuples.
+
+ Args:
+ nest1: An arbitrarily nested structure.
+ nest2: An arbitrarily nested structure.
+ check_types: If `True` (default) types of sequences are checked as
+ well, including the keys of dictionaries. If set to `False`, for example
+ a list and a tuple of objects will look the same if they have the same
+ size. Note that namedtuples with identical name and fields are always
+ considered to have the same shallow structure.
+
+ Returns:
+ True if `nest1` and `nest2` have the same structure.
+
+ Raises:
+ ValueError: If the two structure don't have the same nested structre.
+ TypeError: If the two structure don't have the same sequence type.
+ ValueError: If the two dictionaries don't have the same set of keys.
+ """
is_sequence_nest1 = is_sequence(nest1)
if is_sequence_nest1 != is_sequence(nest2):
raise ValueError(
@@ -154,11 +198,21 @@ def _recursive_assert_same_structure(nest1, nest2, check_types):
if check_types:
type_nest1 = type(nest1)
type_nest2 = type(nest2)
- if type_nest1 != type_nest2:
- raise TypeError(
- "The two structures don't have the same sequence type. First "
- "structure has type %s, while second structure has type %s."
- % (type_nest1, type_nest2))
+
+ # Duck-typing means that nest should be fine with two different namedtuples
+ # with identical name and fields.
+ if _is_namedtuple(nest1, True) and _is_namedtuple(nest2, True):
+ if not _same_namedtuples(nest1, nest2):
+ raise TypeError(
+ "The two namedtuples don't have the same sequence type. First "
+ "structure has type %s, while second structure has type %s."
+ % (type_nest1, type_nest2))
+ else:
+ if type_nest1 != type_nest2:
+ raise TypeError(
+ "The two structures don't have the same sequence type. First "
+ "structure has type %s, while second structure has type %s."
+ % (type_nest1, type_nest2))
if isinstance(nest1, dict):
keys1 = set(_six.iterkeys(nest1))
@@ -178,13 +232,24 @@ def _recursive_assert_same_structure(nest1, nest2, check_types):
def assert_same_structure(nest1, nest2, check_types=True):
"""Asserts that two structures are nested in the same way.
+ Note that namedtuples with identical name and fields are always considered
+ to have the same shallow structure (even with `check_types=True`).
+ For intance, this code will print `True`:
+
+ ```python
+ def nt(a, b):
+ return collections.namedtuple('foo', 'a b')(a, b)
+ print(assert_same_structure(nt(0, 1), nt(2, 3)))
+ ```
+
Args:
nest1: an arbitrarily nested structure.
nest2: an arbitrarily nested structure.
check_types: if `True` (default) types of sequences are checked as
well, including the keys of dictionaries. If set to `False`, for example
a list and a tuple of objects will look the same if they have the same
- size.
+ size. Note that namedtuples with identical name and fields are always
+ considered to have the same shallow structure.
Raises:
ValueError: If the two structures do not have the same number of elements or
@@ -354,6 +419,8 @@ def map_structure(func, *structure, **check_types_dict):
`True` (default) the types of iterables within the structures have to be
same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError`
exception). To allow this set this argument to `False`.
+ Note that namedtuples with identical name and fields are always
+ considered to have the same shallow structure.
Returns:
A new structure with the same arity as `structure`, whose values correspond
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 4906649f01..6bec397db5 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -258,6 +258,36 @@ class NestTest(test.TestCase):
"don't have the same set of keys"):
nest.assert_same_structure({"a": 1}, {"b": 1})
+ same_name_type_0 = collections.namedtuple("same_name", ("a", "b"))
+ same_name_type_1 = collections.namedtuple("same_name", ("a", "b"))
+ nest.assert_same_structure(same_name_type_0(0, 1), same_name_type_1(2, 3))
+
+ # This assertion is expected to pass: two namedtuples with the same
+ # name and field names are considered to be identical.
+ same_name_type_2 = collections.namedtuple("same_name_1", ("x", "y"))
+ same_name_type_3 = collections.namedtuple("same_name_1", ("x", "y"))
+ nest.assert_same_structure(
+ same_name_type_0(same_name_type_2(0, 1), 2),
+ same_name_type_1(same_name_type_3(2, 3), 4))
+
+ expected_message = "The two structures don't have the same.*"
+ with self.assertRaisesRegexp(ValueError, expected_message):
+ nest.assert_same_structure(same_name_type_0(0, same_name_type_1(1, 2)),
+ same_name_type_1(same_name_type_0(0, 1), 2))
+
+ same_name_type_1 = collections.namedtuple("not_same_name", ("a", "b"))
+ self.assertRaises(TypeError, nest.assert_same_structure,
+ same_name_type_0(0, 1), same_name_type_1(2, 3))
+
+ same_name_type_1 = collections.namedtuple("same_name", ("x", "y"))
+ self.assertRaises(TypeError, nest.assert_same_structure,
+ same_name_type_0(0, 1), same_name_type_1(2, 3))
+
+ class SameNamedType1(collections.namedtuple("same_name", ("a", "b"))):
+ pass
+ self.assertRaises(TypeError, nest.assert_same_structure,
+ same_name_type_0(0, 1), SameNamedType1(2, 3))
+
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = (((7, 8), 9), 10, (11, 12))
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index d14e710388..c4168f7b1a 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -117,7 +117,16 @@ def getdoc(object): # pylint: disable=redefined-builtin
def getfile(object): # pylint: disable=redefined-builtin
"""TFDecorator-aware replacement for inspect.getfile."""
- return _inspect.getfile(tf_decorator.unwrap(object)[1])
+ unwrapped_object = tf_decorator.unwrap(object)[1]
+
+ # Work around for the case when object is a stack frame
+ # and only .pyc files are used. In this case, getfile
+ # might return incorrect path. So, we get the path from f_globals
+ # instead.
+ if (hasattr(unwrapped_object, 'f_globals') and
+ '__file__' in unwrapped_object.f_globals):
+ return unwrapped_object.f_globals['__file__']
+ return _inspect.getfile(unwrapped_object)
def getmembers(object, predicate=None): # pylint: disable=redefined-builtin
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 493d26b497..2af71dc753 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Functions for getting information about kernels registered in the binary.
-#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_UTIL_H_
-#define THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_UTIL_H_
+#ifndef TENSORFLOW_PYTHON_UTIL_UTIL_H_
+#define TENSORFLOW_PYTHON_UTIL_UTIL_H_
#include <Python.h>
@@ -71,4 +71,4 @@ void RegisterSequenceClass(PyObject* sequence_class);
} // namespace swig
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_UTIL_H_
+#endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
index f35542e18f..933c103f52 100644
--- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
+++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
@@ -232,7 +232,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
result = StringToDriverVersion(version);
}
#else
-#if !defined(PLATFORM_WINDOWS) && !defined(NVIDIA_TEGRA)
+#if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA)
// Callback used when iterating through DSOs. Looks for the driver-interfacing
// DSO and yields its version number into the callback data, when found.
auto iterate_phdr =
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index 5210a81092..d71938634d 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -96,10 +96,18 @@ string GetCudnnVersion() { return TF_CUDNN_VERSION; }
}
/* static */ port::Status DsoLoader::GetLibcuptiDsoHandle(void** dso_handle) {
+#if defined(ANDROID_TEGRA)
+ // On Android devices the CUDA version number is not added to the library name.
+ return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName(
+ "cupti", ""),
+ GetCudaCuptiLibraryPath()),
+ dso_handle);
+#else
return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName(
"cupti", GetCudaVersion()),
GetCudaCuptiLibraryPath()),
dso_handle);
+#endif
}
static mutex& GetRpathMutex() {
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 838b1218a4..b33f32fdfb 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -11,6 +11,10 @@ load(
"if_static",
)
load(
+ "@local_config_tensorrt//:build_defs.bzl",
+ "if_tensorrt",
+)
+load(
"@local_config_cuda//cuda:build_defs.bzl",
"if_cuda",
"cuda_default_copts",
@@ -197,6 +201,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False):
"-fno-exceptions",
"-ftemplate-depth=900"])
+ if_cuda(["-DGOOGLE_CUDA=1"])
+ + if_tensorrt(["-DGOOGLE_TENSORRT=1"])
+ if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML", "-fopenmp",])
+ if_android_arm(["-mfpu=neon"])
+ if_linux_x86_64(["-msse3"])
@@ -258,6 +263,8 @@ def _rpath_linkopts(name):
clean_dep("//tensorflow:darwin"): [
"-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),),
],
+ clean_dep("//tensorflow:windows"): [],
+ clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
"-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root),),
],
@@ -289,6 +296,7 @@ def tf_cc_shared_object(
"-Wl,-install_name,@rpath/" + name.split("/")[-1],
],
"//conditions:default": [
+ "-Wl,-soname," + name.split("/")[-1],
],
}),
**kwargs)
@@ -600,6 +608,8 @@ def tf_cc_test(name,
"//tensorflow:android": [
"-pie",
],
+ clean_dep("//tensorflow:windows"): [],
+ clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
"-lpthread",
"-lm"
@@ -861,9 +871,11 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
When the library is built with --config=cuda:
- - both deps and cuda_deps are used as dependencies
- - the cuda runtime is added as a dependency (if necessary)
- - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts
+ - Both deps and cuda_deps are used as dependencies.
+ - The cuda runtime is added as a dependency (if necessary).
+ - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts.
+ - In addition, when the library is also built with TensorRT enabled, it
+ additionally passes -DGOOGLE_TENSORRT=1 to the list of copts.
Args:
- cuda_deps: BUILD dependencies which will be linked if and only if:
@@ -882,7 +894,8 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
clean_dep("//tensorflow/core:cuda"),
"@local_config_cuda//cuda:cuda_headers"
]),
- copts=copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]),
+ copts=(copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
+ if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs)
register_extension_info(
@@ -1246,6 +1259,8 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]):
"//conditions:default": [
"-lm",
],
+ clean_dep("//tensorflow:windows"): [],
+ clean_dep("//tensorflow:windows_msvc"): [],
clean_dep("//tensorflow:darwin"): [],
}),)
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index fa0f9b59aa..d110316395 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -46,11 +46,37 @@ genrule(
"api/bitwise/__init__.py",
"api/contrib/__init__.py",
"api/contrib/stat_summarizer/__init__.py",
+ "api/distributions/__init__.py",
+ "api/distributions/bijectors/__init__.py",
+ "api/errors/__init__.py",
"api/image/__init__.py",
"api/linalg/__init__.py",
"api/nn/__init__.py",
"api/spectral/__init__.py",
"api/train/__init__.py",
+ "api/app/__init__.py",
+ "api/gfile/__init__.py",
+ "api/graph_util/__init__.py",
+ "api/keras/__init__.py",
+ "api/keras/backend/__init__.py",
+ "api/keras/datasets/__init__.py",
+ "api/keras/datasets/boston_housing/__init__.py",
+ "api/keras/datasets/cifar10/__init__.py",
+ "api/keras/datasets/cifar100/__init__.py",
+ "api/keras/datasets/imdb/__init__.py",
+ "api/keras/datasets/mnist/__init__.py",
+ "api/keras/datasets/reuters/__init__.py",
+ "api/keras/utils/__init__.py",
+ "api/logging/__init__.py",
+ "api/resource_loader/__init__.py",
+ "api/sysconfig/__init__.py",
+ "api/test/__init__.py",
+ "api/initializers/__init__.py",
+ "api/keras/initializers/__init__.py",
+ "api/metrics/__init__.py",
+ "api/nn/rnn_cell/__init__.py",
+ "api/sets/__init__.py",
+ "api/summary/__init__.py",
],
cmd = "$(location create_python_api) $(OUTS)",
tools = ["create_python_api"],
@@ -60,4 +86,7 @@ py_library(
name = "python_api",
srcs = [":python_api_gen"],
srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib:contrib_py", # keep
+ ],
)
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index aab856b723..1557314939 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -31,6 +31,7 @@ from tensorflow.python.util import tf_decorator
_API_CONSTANTS_ATTR = '_tf_api_constants'
_API_NAMES_ATTR = '_tf_api_names'
_API_DIR = '/api/'
+_CONTRIB_IMPORT = 'from tensorflow import contrib'
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
This file is MACHINE GENERATED! Do not edit.
@@ -50,11 +51,17 @@ def format_import(source_module_name, source_name, dest_name):
Returns:
An import statement string.
"""
- if source_name == dest_name:
- return 'from %s import %s' % (source_module_name, source_name)
+ if source_module_name:
+ if source_name == dest_name:
+ return 'from %s import %s' % (source_module_name, source_name)
+ else:
+ return 'from %s import %s as %s' % (
+ source_module_name, source_name, dest_name)
else:
- return 'from %s import %s as %s' % (
- source_module_name, source_name, dest_name)
+ if source_name == dest_name:
+ return 'import %s' % source_name
+ else:
+ return 'import %s as %s' % (source_name, dest_name)
def get_api_imports():
@@ -74,6 +81,9 @@ def get_api_imports():
# Only look at tensorflow modules.
if not module or 'tensorflow.' not in module.__name__:
continue
+ # Do not generate __init__.py files for contrib modules for now.
+ if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
+ continue
for module_contents_name in dir(module):
attr = getattr(module, module_contents_name)
@@ -151,21 +161,28 @@ def create_api_files(output_files):
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- # Add imports to output files.
module_imports = get_api_imports()
+ module_imports['tf'].append(_CONTRIB_IMPORT) # Include all of contrib.
+
+ # Add imports to output files.
missing_output_files = []
for module, exports in module_imports.items():
# Make sure genrule output file list is in sync with API exports.
if module not in module_name_to_file_path:
- missing_output_files.append(module)
+ module_without_tf = module[len('tf.'):]
+ module_file_path = '"api/%s/__init__.py"' % (
+ module_without_tf.replace('.', '/'))
+ missing_output_files.append(module_file_path)
continue
with open(module_name_to_file_path[module], 'w') as fp:
fp.write(_GENERATED_FILE_HEADER + '\n'.join(exports))
if missing_output_files:
raise ValueError(
- 'Missing outputs for python_api_gen genrule:\n%s' %
- ',\n'.join(missing_output_files))
+ 'Missing outputs for python_api_gen genrule:\n%s.'
+ 'Make sure all required outputs are in the '
+ 'tensorflow/tools/api/generator/BUILD file.' %
+ ',\n'.join(sorted(missing_output_files)))
def main(output_files):
diff --git a/tensorflow/tools/api/golden/tensorflow.compat.pbtxt b/tensorflow/tools/api/golden/tensorflow.compat.pbtxt
index ccc6031400..bab480ff9b 100644
--- a/tensorflow/tools/api/golden/tensorflow.compat.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.compat.pbtxt
@@ -32,4 +32,8 @@ tf_module {
name: "as_text"
argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], "
}
+ member_method {
+ name: "path_to_str"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
index 46d5957057..efc441ae2f 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
index 439e87375b..20ce879870 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
index f79a8be3f6..73211aaf8b 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
index c466dcb4c2..27a159639d 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
index d0bf043754..76f527f796 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
@@ -20,7 +20,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_fn\', \'model_dir\', \'config\', \'params\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'model_fn\', \'model_dir\', \'config\', \'params\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
index cb9e95588d..c45318b98a 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
index 637f19ba26..04a2aa080d 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt
index 018e8c909a..24a58fb118 100644
--- a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt
@@ -49,6 +49,10 @@ tf_module {
argspec: "args=[\'key\', \'shape\', \'default_value\', \'dtype\', \'normalizer_fn\'], varargs=None, keywords=None, defaults=[\'(1,)\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
+ name: "shared_embedding_columns"
+ argspec: "args=[\'categorical_columns\', \'dimension\', \'combiner\', \'initializer\', \'shared_embedding_collection_name\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\'], "
+ }
+ member_method {
name: "weighted_categorical_column"
argspec: "args=[\'categorical_column\', \'weight_feature_key\', \'dtype\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\"], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index f32353c957..baedf596e8 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -169,12 +169,20 @@ tf_module {
argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "rgb_to_yiq"
+ argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "rgb_to_yuv"
+ argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "rot90"
argspec: "args=[\'image\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "sample_distorted_bounding_box"
- argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "total_variation"
@@ -184,4 +192,12 @@ tf_module {
name: "transpose_image"
argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
}
+ member_method {
+ name: "yiq_to_rgb"
+ argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "yuv_to_rgb"
+ argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
index 7fe3e2db09..2bf584fa29 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
@@ -160,15 +160,15 @@ tf_class {
}
member_method {
name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
}
member_method {
name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
}
member_method {
name: "from_config"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
}
member_method {
name: "predict_on_batch"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt
new file mode 100644
index 0000000000..42cb914450
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt
@@ -0,0 +1,23 @@
+path: "tensorflow.keras.applications.densenet"
+tf_module {
+ member_method {
+ name: "DenseNet121"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "DenseNet169"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "DenseNet201"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "decode_predictions"
+ argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+ }
+ member_method {
+ name: "preprocess_input"
+ argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt
new file mode 100644
index 0000000000..cd75b87540
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.applications.nasnet"
+tf_module {
+ member_method {
+ name: "NASNetLarge"
+ argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "NASNetMobile"
+ argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "decode_predictions"
+ argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+ }
+ member_method {
+ name: "preprocess_input"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt
index daeb5aad41..9fc086eb8e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.keras.applications"
tf_module {
member {
+ name: "densenet"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "inception_resnet_v2"
mtype: "<type \'module\'>"
}
@@ -13,6 +17,10 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "nasnet"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "resnet50"
mtype: "<type \'module\'>"
}
@@ -29,6 +37,18 @@ tf_module {
mtype: "<type \'module\'>"
}
member_method {
+ name: "DenseNet121"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "DenseNet169"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "DenseNet201"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
name: "InceptionResNetV2"
argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
}
@@ -41,6 +61,14 @@ tf_module {
argspec: "args=[\'input_shape\', \'alpha\', \'depth_multiplier\', \'dropout\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'1\', \'0.001\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
}
member_method {
+ name: "NASNetLarge"
+ argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "NASNetMobile"
+ argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
name: "ResNet50"
argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
index 44fbe0f7a0..ba2d083a75 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
@@ -398,7 +398,7 @@ tf_module {
}
member_method {
name: "rnn"
- argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\'], "
+ argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "round"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt
index 8719c07ca3..d4c85a4519 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'schedule\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'schedule\', \'verbose\'], varargs=None, keywords=None, defaults=[\'0\'], "
}
member_method {
name: "on_batch_begin"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt
index ef08f9b20f..bda31751d4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
tf_module {
member_method {
name: "load_data"
- argspec: "args=[\'path\', \'seed\', \'test_split\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'113\', \'0.2\'], "
+ argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt
index 8b1c17e9da..ff962876b6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt
@@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
- argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=None, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
+ argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt
index 6b3ed1e9af..2da4a13067 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt
@@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
- argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=None, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
+ argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
index a32151e22f..770a107b66 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
index 46b1713196..0ce42b706e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
index 9bfaf27562..b371ad148c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
index 2b8ac4f1f4..2f5e65a0c5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -123,7 +123,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
@@ -131,7 +131,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
index c9a0b88725..ff08def0a0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index b847e224d6..6db22ca032 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -116,7 +116,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -128,7 +128,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index d898c54627..11e05f884d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.Conv3DTranspose"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv3DTranspose\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3DTranspose\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index a7001bbe34..58724a1e16 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.Convolution3DTranspose"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv3DTranspose\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3DTranspose\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.Conv3D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
index 86578d958e..07d3f023e5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
index 348012dcde..92b9760d53 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
index 0419251083..83c528b401 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 337e85e812..b329f1c46b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index 1357dc0f0d..d0f6d2a14f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -183,7 +183,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -195,7 +195,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index b71a08f6c3..57596badf1 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
index a01a6067ef..3829353cc3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 0dbbdf2838..3b171b137a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 964ef89c2e..0036d6805b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -187,7 +187,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -199,7 +199,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index 6a7b23c540..8134fb7386 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index 324745e5a3..c5d4523009 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index e12ae05054..bcbed9241b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
index 9e889ca863..ff0db15f19 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
index 932680941d..1d3f33f045 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
index db644f958f..c86bc49b22 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
index 74fa1db020..b29f65d79d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -94,7 +94,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'activity_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
@@ -118,7 +118,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -130,7 +130,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
new file mode 100644
index 0000000000..dd67b76523
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -0,0 +1,186 @@
+path: "tensorflow.keras.layers.SeparableConv1D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index a6d9b57c88..5d898fb2bd 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.SeparableConv2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.SeparableConv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv2D\'>"
- is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
new file mode 100644
index 0000000000..bf62c095e7
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -0,0 +1,186 @@
+path: "tensorflow.keras.layers.SeparableConvolution1D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index 551d695379..c758d87993 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.SeparableConvolution2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.SeparableConv2D\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv2D\'>"
- is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 3414810db4..6e3cde3e3e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index cf34034ef0..b875898a81 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -175,7 +175,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -187,7 +187,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
new file mode 100644
index 0000000000..ee4b2fa39e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
@@ -0,0 +1,183 @@
+path: "tensorflow.keras.layers.Softmax"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.advanced_activations.Softmax\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index b76499658d..db9f90caef 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -118,7 +118,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index 2376d815a6..ef31c5443e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
index fe336c4be5..088c8e88e2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
@@ -293,10 +293,18 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "SeparableConv1D"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "SeparableConv2D"
mtype: "<type \'type\'>"
}
member {
+ name: "SeparableConvolution1D"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "SeparableConvolution2D"
mtype: "<type \'type\'>"
}
@@ -309,6 +317,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "Softmax"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "SpatialDropout1D"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
index d239098b0b..0b816b5863 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -160,15 +160,15 @@ tf_class {
}
member_method {
name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
}
member_method {
name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
}
member_method {
name: "from_config"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
}
member_method {
name: "predict_on_batch"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt
index ed040c1586..32667cf31e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'None\', \'0.0\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt
index a24651429a..efca59e8e4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'None\', \'0.0\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt
index a0d978fded..5546e2067a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\', \'amsgrad\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'None\', \'0.0\', \'False\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt
index 1b70c93ad5..aaa54a1060 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.0\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt
index b49dbe5cf8..1fada7fd9c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'schedule_decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'1e-08\', \'0.004\'], "
+ argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'schedule_decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.004\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
index c8860d80d4..fd3f97f35d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'None\', \'0.0\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
index 5bc8c40120..ce91caa1af 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_words\', \'filters\', \'lower\', \'split\', \'char_level\'], varargs=None, keywords=None, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \', \'False\'], "
+ argspec: "args=[\'self\', \'num_words\', \'filters\', \'lower\', \'split\', \'char_level\', \'oov_token\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \', \'False\', \'None\'], "
}
member_method {
name: "fit_on_sequences"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt
new file mode 100644
index 0000000000..05799ecfc9
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt
@@ -0,0 +1,144 @@
+path: "tensorflow.layers.SeparableConv1D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'None\', \'None\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
index 4d91ab1d8c..c2aeb35c46 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.layers.SeparableConv2D"
tf_class {
is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv2D\'>"
- is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<type \'object\'>"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
index c45d6e6c05..59134f8489 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
@@ -69,6 +69,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "SeparableConv1D"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "SeparableConv2D"
mtype: "<type \'type\'>"
}
@@ -137,6 +141,10 @@ tf_module {
argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], "
}
member_method {
+ name: "separable_conv1d"
+ argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'None\', \'None\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
name: "separable_conv2d"
argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'(1, 1)\', \'1\', \'None\', \'True\', \'None\', \'None\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index 8ce022e454..455590d866 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -262,7 +262,7 @@ tf_module {
}
member_method {
name: "sampled_softmax_loss"
- argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\'], "
+ argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\', \'None\'], "
}
member_method {
name: "selu"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index f61a5a28e3..97edf245f6 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -88,6 +88,10 @@ tf_class {
name: "weights"
mtype: "<type \'property\'>"
}
+ member {
+ name: "wrapped_cell"
+ mtype: "<type \'property\'>"
+ }
member_method {
name: "__init__"
argspec: "args=[\'self\', \'cell\', \'input_keep_prob\', \'output_keep_prob\', \'state_keep_prob\', \'variational_recurrent\', \'input_size\', \'dtype\', \'seed\', \'dropout_state_filter_visitor\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'1.0\', \'False\', \'None\', \'None\', \'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 35917e94ad..db1ed42185 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1161,6 +1161,10 @@ tf_module {
argspec: "args=[\'values\', \'value_range\', \'nbins\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'100\', \"<dtype: \'int32\'>\", \'None\'], "
}
member_method {
+ name: "histogram_fixed_width_bins"
+ argspec: "args=[\'values\', \'value_range\', \'nbins\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'100\', \"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "identity"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.pbtxt
index 326e077d39..871ebb5247 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.pbtxt
@@ -50,7 +50,7 @@ tf_module {
}
member_method {
name: "merge_all"
- argspec: "args=[\'key\'], varargs=None, keywords=None, defaults=[\'summaries\'], "
+ argspec: "args=[\'key\', \'scope\'], varargs=None, keywords=None, defaults=[\'summaries\', \'None\'], "
}
member_method {
name: "scalar"
diff --git a/tensorflow/tools/benchmark/BUILD b/tensorflow/tools/benchmark/BUILD
index caa6629c49..6ed2594e6a 100644
--- a/tensorflow/tools/benchmark/BUILD
+++ b/tensorflow/tools/benchmark/BUILD
@@ -61,10 +61,11 @@ tf_cc_test(
# This binary may be built for either desktop or Android.
# A typical Android build command will look like the following:
-# bazel build -c opt tensorflow/core:android_tensorflow_lib \
+# bazel build tensorflow/core:android_tensorflow_lib \
# --crosstool_top=//external:android/crosstool \
# --cpu=armeabi-v7a \
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
+# --config monolithic
tf_cc_binary(
name = "benchmark_model",
testonly = 1,
diff --git a/tensorflow/tools/benchmark/README.md b/tensorflow/tools/benchmark/README.md
index ca0da2d41b..e64af2bfe1 100644
--- a/tensorflow/tools/benchmark/README.md
+++ b/tensorflow/tools/benchmark/README.md
@@ -17,6 +17,7 @@ bazel build -c opt \
--crosstool_top=//external:android/crosstool \
--cpu=armeabi-v7a \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
+ --config monolithic \
tensorflow/tools/benchmark:benchmark_model
```
diff --git a/tensorflow/tools/ci_build/builds/libtensorflow.sh b/tensorflow/tools/ci_build/builds/libtensorflow.sh
index 26713dded8..9b3ff0cba7 100755
--- a/tensorflow/tools/ci_build/builds/libtensorflow.sh
+++ b/tensorflow/tools/ci_build/builds/libtensorflow.sh
@@ -51,8 +51,8 @@ function build_libtensorflow_tarball() {
rm -rf ${DIR}
TARBALL_SUFFIX="${1}"
- BAZEL="bazel --bazelrc ./tensorflow/tools/ci_build/install/.bazelrc"
- BAZEL_OPTS="-c opt"
+ BAZEL_OPTS="-c opt --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0"
+ export CC_OPT_FLAGS='-mavx'
if [ "${TF_NEED_CUDA}" == "1" ]; then
BAZEL_OPTS="${BAZEL_OPTS} --config=cuda"
fi
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index b728c878da..27fa1b89ce 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -26,6 +26,8 @@
SCRIPT_DIR=$( cd ${0%/*} && pwd -P )
source "${SCRIPT_DIR}/builds/builds_common.sh"
+ROOT_DIR=$( cd "$SCRIPT_DIR/../../.." && pwd -P )
+
# Helper functions
die() {
echo $@
@@ -175,7 +177,13 @@ do_pylint() {
echo "pylint took $((PYLINT_END_TIME - PYLINT_START_TIME)) s"
echo ""
- grep -E '(\[E|\[W0311|\[W0312)' ${OUTPUT_FILE} > ${ERRORS_FILE}
+ # Report only what we care about
+ # Ref https://pylint.readthedocs.io/en/latest/technical_reference/features.html
+ # E: all errors
+ # W0311 bad-indentation
+ # W0312 mixed-indentation
+ # C0330 bad-continuation
+ grep -E '(\[E|\[W0311|\[W0312|\[C0330)' ${OUTPUT_FILE} > ${ERRORS_FILE}
N_ERRORS=0
while read -r LINE; do
@@ -418,15 +426,8 @@ do_bazel_nobuild() {
}
do_pip_smoke_test() {
- BUILD_CMD="bazel build ${BAZEL_FLAGS} //tensorflow/tools/pip_package:pip_smoke_test"
- ${BUILD_CMD}
- cmd_status \
- "Pip smoke test has failed. Please make sure any new TensorFlow are added to the tensorflow/tools/pip_package:build_pip_package dependencies."
-
- RUN_CMD="bazel-bin/tensorflow/tools/pip_package/pip_smoke_test"
- ${RUN_CMD}
- cmd_status \
- "The pip smoke test failed."
+ cd "$ROOT_DIR/tensorflow/tools/pip_package"
+ python pip_smoke_test.py
}
do_code_link_check() {
@@ -500,20 +501,23 @@ do_clang_format_check() {
}
do_check_load_py_test() {
- BUILD_CMD="bazel build ${BAZEL_FLAGS} //tensorflow/tools/pip_package:check_load_py_test"
- ${BUILD_CMD}
- cmd_status \
- "check_load_py_test failed to build."
+ cd "$ROOT_DIR/tensorflow/tools/pip_package"
+ python check_load_py_test.py
+}
- BUILD_CMD="bazel-bin/tensorflow/tools/pip_package/check_load_py_test"
- ${BUILD_CMD}
- cmd_status \
- "check_load_py_test failed."
+do_cmake_python_sanity() {
+ cd "$ROOT_DIR/tensorflow/contrib/cmake"
+ python -m unittest -v python_sanity_test
+}
+
+do_check_futures_test() {
+ cd "$ROOT_DIR/tensorflow/tools/test"
+ python check_futures_test.py
}
# Supply all sanity step commands and descriptions
-SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check")
-SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links")
+SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_cmake_python_sanity")
+SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency")
INCREMENTAL_FLAG=""
DEFAULT_BAZEL_CONFIGS="--config=hdfs --config=gcp"
@@ -548,7 +552,10 @@ while [[ ${COUNTER} -lt "${#SANITY_STEPS[@]}" ]]; do
"${SANITY_STEPS[COUNTER]} (${SANITY_STEPS_DESC[COUNTER]}) ==="
echo ""
+ # subshell: don't leak variables or changes of working directory
+ (
${SANITY_STEPS[COUNTER]} ${INCREMENTAL_FLAG}
+ )
RESULT=$?
if [[ ${RESULT} != "0" ]]; then
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
index e3e6b2f316..51e10f81f8 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
@@ -26,12 +26,13 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export CC_OPT_FLAGS='-mavx'
# Only running cc tests, python version does not matter.
export PYTHON_BIN_PATH=`which python`
yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc,java -k \
- --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \
--test_output=errors -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
index 5110d52f31..ea14848b1a 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
@@ -26,11 +26,12 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python2`
yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
- --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \
+ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \
--test_output=errors -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
index df6016504c..6d017c8a1f 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
@@ -26,12 +26,13 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python3`
yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \
- --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \
--test_output=errors -- \
//tensorflow/contrib/... \
-//tensorflow/contrib/lite/... \
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
index ea9e102936..a9accb9dd5 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
@@ -26,11 +26,12 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python3`
yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
- --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \
+ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \
--test_output=errors -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
index df196f829c..02224d8e9d 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export PYTHON_BIN_PATH=`which python3`
+export CC_OPT_FLAGS='-mavx'
export TF_NEED_CUDA=1
export TF_CUDA_COMPUTE_CAPABILITIES=3.7
@@ -35,6 +36,6 @@ yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
--test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
- --build_tests_only --test_output=errors --local_test_jobs=8 \
+ --build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
index abd256a895..0367a53d14 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export PYTHON_BIN_PATH=`which python3`
+export CC_OPT_FLAGS='-mavx'
export TF_NEED_CUDA=1
export TF_CUDA_COMPUTE_CAPABILITIES=3.7
@@ -35,6 +36,6 @@ yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
--test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
- --build_tests_only --test_output=errors --local_test_jobs=8 \
+ --build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh
index ddaaddc917..509ee38ec4 100755
--- a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh
+++ b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh
@@ -27,11 +27,12 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=$(which python2)
yes "" | $PYTHON_BIN_PATH configure.py
which bazel
bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \
--test_timeout 300,450,1200,3600 \
- --test_size_filters=small,medium \
+ --test_size_filters=small,medium --config=opt \
--jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \
//tensorflow/contrib/... -//tensorflow/contrib/lite/...
diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
index e026dcd08f..0554713670 100755
--- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
+++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
@@ -27,11 +27,12 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=$(which python2)
yes "" | $PYTHON_BIN_PATH configure.py
which bazel
bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \
- --test_timeout 300,450,1200,3600 \
+ --test_timeout 300,450,1200,3600 --config=opt \
--test_size_filters=small,medium \
--jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index d2a63e5d66..347d0769a9 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -25,19 +25,19 @@
# pylint: disable=superfluous-parens
import argparse
-import fileinput
import os
import re
import subprocess
import time
-# File parameters
+# File parameters.
TF_SRC_DIR = "tensorflow"
VERSION_H = "%s/core/public/version.h" % TF_SRC_DIR
SETUP_PY = "%s/tools/pip_package/setup.py" % TF_SRC_DIR
README_MD = "./README.md"
DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel" % TF_SRC_DIR
GPU_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-gpu" % TF_SRC_DIR
+CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-cpu-mkl" % TF_SRC_DIR
RELEVANT_FILES = [TF_SRC_DIR,
VERSION_H,
SETUP_PY,
@@ -45,17 +45,11 @@ RELEVANT_FILES = [TF_SRC_DIR,
DEVEL_DOCKERFILE,
GPU_DEVEL_DOCKERFILE]
-# Version type parameters
+# Version type parameters.
NIGHTLY_VERSION = 1
REGULAR_VERSION = 0
-def replace_line(old_line, new_line, filename):
- """Replace a line in a file."""
- for line in fileinput.input(filename, inplace=True):
- print(line.rstrip().replace(old_line, new_line))
-
-
def check_existence(filename):
"""Check the existence of file or dir."""
if not os.path.exists(filename):
@@ -69,9 +63,12 @@ def check_all_files():
check_existence(file_name)
-def replace_with_sed(query, filename):
+def replace_string_in_line(search, replace, filename):
"""Replace with sed when regex is required."""
- subprocess.check_call(['sed', '-i', '-r', '-e', query, filename])
+ with open(filename, "r") as source:
+ content = source.read()
+ with open(filename, "w") as source:
+ source.write(re.sub(search, replace, content))
class Version(object):
@@ -125,13 +122,13 @@ class Version(object):
Raises:
RuntimeError: If the version string is not valid.
"""
- # Check validity of new version string
+ # Check validity of new version string.
if not re.search(r"[0-9]+\.[0-9]+\.[a-zA-Z0-9]+", string):
raise RuntimeError("Invalid version string: %s" % string)
major, minor, extension = string.split(".", 2)
- # Isolate patch and identifier string if identifier string exists
+ # Isolate patch and identifier string if identifier string exists.
extension_split = extension.split("-", 1)
patch = extension_split[0]
if len(extension_split) == 2:
@@ -154,7 +151,7 @@ def get_current_semver_version():
core/public/version.h
"""
- # Get current version information
+ # Get current version information.
version_file = open(VERSION_H, "r")
for line in version_file:
major_match = re.search("^#define TF_MAJOR_VERSION ([0-9]+)", line)
@@ -185,32 +182,33 @@ def get_current_semver_version():
def update_version_h(old_version, new_version):
"""Update tensorflow/core/public/version.h."""
- replace_line("#define TF_MAJOR_VERSION %s" % old_version.major,
- "#define TF_MAJOR_VERSION %s" % new_version.major, VERSION_H)
- replace_line("#define TF_MINOR_VERSION %s" % old_version.minor,
- "#define TF_MINOR_VERSION %s" % new_version.minor, VERSION_H)
- replace_line("#define TF_PATCH_VERSION %s" % old_version.patch,
- "#define TF_PATCH_VERSION %s" % new_version.patch, VERSION_H)
- replace_line("#define TF_VERSION_SUFFIX \"%s\"" %
- old_version.identifier_string,
- "#define TF_VERSION_SUFFIX \"%s\""
- % new_version.identifier_string,
- VERSION_H)
+ replace_string_in_line("#define TF_MAJOR_VERSION %s" % old_version.major,
+ "#define TF_MAJOR_VERSION %s" % new_version.major,
+ VERSION_H)
+ replace_string_in_line("#define TF_MINOR_VERSION %s" % old_version.minor,
+ "#define TF_MINOR_VERSION %s" % new_version.minor,
+ VERSION_H)
+ replace_string_in_line("#define TF_PATCH_VERSION %s" % old_version.patch,
+ "#define TF_PATCH_VERSION %s" % new_version.patch,
+ VERSION_H)
+ replace_string_in_line(
+ "#define TF_VERSION_SUFFIX \"%s\"" % old_version.identifier_string,
+ "#define TF_VERSION_SUFFIX \"%s\"" % new_version.identifier_string,
+ VERSION_H)
def update_setup_dot_py(old_version, new_version):
"""Update setup.py."""
- replace_line("_VERSION = '%s'" % old_version.string,
- "_VERSION = '%s'" % new_version.string, SETUP_PY)
+ replace_string_in_line("_VERSION = '%s'" % old_version.string,
+ "_VERSION = '%s'" % new_version.string, SETUP_PY)
def update_readme(old_version, new_version):
"""Update README."""
pep_440_str = new_version.pep_440_str
- replace_with_sed(r"s/%s\.%s\.([[:alnum:]]+)-/%s-/g" % (old_version.major,
- old_version.minor,
- pep_440_str),
- README_MD)
+ replace_string_in_line(r"%s\.%s\.([[:alnum:]]+)-" % (old_version.major,
+ old_version.minor),
+ "%s-" % pep_440_str, README_MD)
def update_md_files(old_version, new_version):
@@ -226,22 +224,29 @@ def update_md_files(old_version, new_version):
for filename in ["linux", "mac", "windows", "sources"]:
filepath = "%s/docs_src/install/install_%s.md" % (TF_SRC_DIR,
filename)
- replace_with_sed("s/tensorflow-%s/tensorflow-%s/g"
- % (old_pep_version, new_pep_version), filepath)
- replace_with_sed("s/tensorflow_gpu-%s/tensorflow_gpu-%s/g"
- % (old_pep_version, new_pep_version), filepath)
- replace_with_sed("s/TensorFlow %s/TensorFlow %s/g"
- % (old_pep_version, new_pep_version), filepath)
+
+ if filename == "sources" and "rc0" in new_pep_version:
+ replace_string_in_line("(?<!<td>)tensorflow-%s" % old_pep_version,
+ "tensorflow-%s" % new_pep_version, filepath)
+ replace_string_in_line("(?<!<td>)tensorflow_gpu-%s" % old_pep_version,
+ "tensorflow_gpu-%s" % new_pep_version, filepath)
+ else:
+ replace_string_in_line("tensorflow-%s" % old_pep_version,
+ "tensorflow-%s" % new_pep_version, filepath)
+ replace_string_in_line("tensorflow_gpu-%s" % old_pep_version,
+ "tensorflow_gpu-%s" % new_pep_version, filepath)
+ replace_string_in_line("TensorFlow %s" % old_pep_version,
+ "TensorFlow %s" % new_pep_version, filepath)
for filename in ["java", "go", "c"]:
filepath = "%s/docs_src/install/install_%s.md" % (TF_SRC_DIR,
filename)
- replace_with_sed(r"s/x86_64-%s/x86_64-%s/g"
- % (old_version, new_version), filepath)
- replace_with_sed(r"s/libtensorflow-%s.jar/libtensorflow-%s.jar/g"
- % (old_version, new_version), filepath)
- replace_with_sed(r"s/<version>%s<\/version>/<version>%s<\/version>/g"
- % (old_version, new_version), filepath)
+ replace_string_in_line(r"x86_64-%s" % old_version,
+ "x86_64-%s" % new_version, filepath)
+ replace_string_in_line(r"libtensorflow-%s.jar" % old_version,
+ "libtensorflow-%s.jar" % new_version, filepath)
+ replace_string_in_line(r"<version>%s<\/version>" % old_version,
+ "<version>%s</version>" % new_version, filepath)
def major_minor_change(old_version, new_version):
@@ -266,10 +271,11 @@ def update_dockerfiles(old_version, new_version):
% (old_r_major_minor_string, r_major_minor_string))
# Update dockerfiles
- replace_with_sed("s/%s/%s/g"
- % (old_r_major_minor, r_major_minor), DEVEL_DOCKERFILE)
- replace_with_sed("s/%s/%s/g"
- % (old_r_major_minor, r_major_minor), GPU_DEVEL_DOCKERFILE)
+ replace_string_in_line(old_r_major_minor, r_major_minor, DEVEL_DOCKERFILE)
+ replace_string_in_line(old_r_major_minor, r_major_minor,
+ GPU_DEVEL_DOCKERFILE)
+ replace_string_in_line(old_r_major_minor, r_major_minor,
+ CPU_MKL_DEVEL_DOCKERFILE)
def check_for_lingering_string(lingering_string):
@@ -333,7 +339,7 @@ def main():
old_version = get_current_semver_version()
if args.nightly:
- # dev minor version is one ahead of official
+ # Dev minor version is one ahead of official.
nightly_minor_ver = int(old_version.minor) + 1
new_version = Version(old_version.major,
str(nightly_minor_ver),
@@ -349,12 +355,18 @@ def main():
update_md_files(old_version, new_version)
update_dockerfiles(old_version, new_version)
- # Print transition details
+ # Print transition details.
print("Major: %s -> %s" % (old_version.major, new_version.major))
print("Minor: %s -> %s" % (old_version.minor, new_version.minor))
print("Patch: %s -> %s\n" % (old_version.patch, new_version.patch))
check_for_old_version(old_version, new_version)
+ if "rc0" in str(new_version):
+ print("\n\n\033[93mNOTE: Please update the tensorflow/docs_src/install/"
+ "install_sources.md and add a line for tensorflow-%s and "
+ "tensorflow_gpu-%s in the tested source configurations "
+ "table.\033[0m\n" % (new_version.pep_440_str,
+ new_version.pep_440_str))
if __name__ == "__main__":
diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py
index fa1cc73905..f678681dac 100644
--- a/tensorflow/tools/compatibility/tf_upgrade.py
+++ b/tensorflow/tools/compatibility/tf_upgrade.py
@@ -236,8 +236,8 @@ class _ASTCallVisitor(ast.NodeVisitor):
new_col_offset = col - m.start(1) - 1
return line, new_col_offset
else:
- if (reversed_preceding_text=="" or
- reversed_preceding_text.isspace()):
+ if (reversed_preceding_text == "" or
+ reversed_preceding_text.isspace()):
line = line - 1
prev_line = self._lines[line - 1]
# TODO(aselle):
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
index 39c040e051..c1b1f79bbd 100644
--- a/tensorflow/tools/dist_test/README.md
+++ b/tensorflow/tools/dist_test/README.md
@@ -17,7 +17,7 @@ cesnsu model:
./local_test.sh --model_name CENSUS_WIDENDEEP
-**2) Launch a remote k8s cluster on Google Container Engine (GKE) and run the
+**2) Launch a remote k8s cluster on Google Kubernetes Engine (GKE) and run the
test suite on it**
For example:
diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh
index fa867b65db..b4fba5b8f5 100755
--- a/tensorflow/tools/docker/parameterized_docker_build.sh
+++ b/tensorflow/tools/docker/parameterized_docker_build.sh
@@ -34,6 +34,11 @@
# If set to a non-empty string, will use it as the URL from which the
# pip wheel file will be downloaded (instead of building the pip locally).
#
+# TF_DOCKER_BUILD_CENTRAL_PIP_IS_LOCAL
+# (Optional)
+# If set to a non-empty string, we will treat TF_DOCKER_BUILD_CENTRAL_PIP
+# as a path rather than a url.
+#
# TF_DOCKER_BUILD_IMAGE_NAME:
# (Optional)
# If set to any non-empty value, will use it as the image of the
@@ -234,6 +239,32 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
"COPY ${PIP_WHL} /\n"\
"RUN pip --no-cache-dir install /${PIP_WHL}" "${ORIG_DOCKERFILE}" \
> "${DOCKERFILE}"
+
+ # Build from a local whl file path rather than an URL
+ elif [[ ! -z "${TF_DOCKER_BUILD_CENTRAL_PIP_IS_LOCAL}" ]]; then
+ PIP_WHL="${TF_DOCKER_BUILD_CENTRAL_PIP}"
+ if [[ -z "${PIP_WHL}" ]]; then
+ die "ERROR: Cannot locate the specified pip whl file"
+ fi
+ echo "Specified PIP whl file is at: ${PIP_WHL}"
+
+ # Copy the pip file to tmp directory
+ cp "${PIP_WHL}" "${TMP_DIR}/" || \
+ die "ERROR: Failed to copy wheel file: ${PIP_WHL}"
+
+ # Use string replacement to put the correct file name into the Dockerfile
+ PIP_WHL=$(basename "${PIP_WHL}")
+
+ # Modify the non-devel Dockerfile to point to the correct pip whl file
+ # location
+ sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\
+"/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\
+"COPY ${PIP_WHL} /\n"\
+"RUN pip --no-cache-dir install /${PIP_WHL}" "${ORIG_DOCKERFILE}" \
+ > "${DOCKERFILE}"
+ echo "Using local pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}"
+ echo
+
else
echo "Downloading pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}"
echo
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index c033c16ae9..b5df633800 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -323,7 +323,7 @@ class _Metadata(object):
"""
def __init__(self, name):
- """Creata a Metadata builder.
+ """Create a Metadata builder.
Args:
name: The name of the page being described by the Metadata block.
diff --git a/tensorflow/tools/graph_transforms/file_utils.h b/tensorflow/tools/graph_transforms/file_utils.h
index 4737e95abc..a3723f5cd3 100644
--- a/tensorflow/tools/graph_transforms/file_utils.h
+++ b/tensorflow/tools/graph_transforms/file_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
-#define THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
+#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
+#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -29,4 +29,4 @@ Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def);
} // namespace graph_transforms
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
+#endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index f47df0e25d..d864d09d8f 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -48,27 +48,6 @@ py_binary(
deps = ["//tensorflow:tensorflow_py"],
)
-py_test(
- name = "pip_smoke_test",
- srcs = ["pip_smoke_test.py"],
- data = [
- "//tensorflow:all_opensource_files",
- ],
- tags = [
- "manual",
- "notap",
- ],
-)
-
-py_binary(
- name = "check_load_py_test",
- srcs = ["check_load_py_test.py"],
- data = [
- "//tensorflow:all_opensource_files",
- ],
- srcs_version = "PY2AND3",
-)
-
# On Windows, python binary is a zip file of runfiles tree.
# Add everything to its data dependency for generating a runfiles tree
# for building the pip package on Windows.
@@ -174,7 +153,8 @@ sh_binary(
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/predictor:predictor_pip",
"//tensorflow/contrib/py2tf:py2tf_internal",
- "//tensorflow/contrib/py2tf/convert:convert",
+ "//tensorflow/contrib/py2tf/converters:converters",
+ "//tensorflow/contrib/py2tf/converters:test_lib",
"//tensorflow/contrib/py2tf/pyct:pyct",
"//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis",
"//tensorflow/contrib/receptive_field:receptive_field_pip",
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index ca8c272a08..dc31e4c5f7 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -137,8 +137,8 @@ function main() {
fi
fi
fi
- # Install toco as a binary in aux-bin.
mkdir "${TMPDIR}/tensorflow/aux-bin"
+ # Install toco as a binary in aux-bin.
cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/
fi
diff --git a/tensorflow/tools/pip_package/check_load_py_test.py b/tensorflow/tools/pip_package/check_load_py_test.py
index 79d11b08ce..e2fe1121d7 100644
--- a/tensorflow/tools/pip_package/check_load_py_test.py
+++ b/tensorflow/tools/pip_package/check_load_py_test.py
@@ -22,6 +22,9 @@ import os
import subprocess
+os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))
+
+
def check_output_despite_error(args):
"""Get output of args from command line, even if there are errors.
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index cddf9c8f44..38a9007387 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""This pip smoke test verifies dependency files exist in the pip package.
This script runs bazel queries to see what python files are required by the
@@ -23,11 +22,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import subprocess
-PIP_PACKAGE_QUERY_EXPRESSION = \
- 'deps(//tensorflow/tools/pip_package:build_pip_package)'
+os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
+
+PIP_PACKAGE_QUERY_EXPRESSION = (
+ "deps(//tensorflow/tools/pip_package:build_pip_package)")
+# pylint: disable=g-backslash-continuation
PY_TEST_QUERY_EXPRESSION = 'deps(\
filter("^((?!benchmark).)*$",\
kind(py_test,\
@@ -35,6 +38,7 @@ PY_TEST_QUERY_EXPRESSION = 'deps(\
+ //tensorflow/contrib/... \
- //tensorflow/contrib/tensorboard/... \
- attr(tags, "manual|no_pip", //tensorflow/...))), 1)'
+# pylint: enable=g-backslash-continuation
# Hard-coded blacklist of files if not included in pip package
# TODO(amitpatankar): Clean up blacklist.
@@ -85,15 +89,15 @@ def main():
"""
# pip_package_dependencies_list is the list of included files in pip packages
- pip_package_dependencies = subprocess.check_output([
- 'bazel', 'query', PIP_PACKAGE_QUERY_EXPRESSION])
+ pip_package_dependencies = subprocess.check_output(
+ ["bazel", "query", PIP_PACKAGE_QUERY_EXPRESSION])
pip_package_dependencies_list = pip_package_dependencies.strip().split("\n")
print("Pip package superset size: %d" % len(pip_package_dependencies_list))
# tf_py_test_dependencies is the list of dependencies for all python
# tests in tensorflow
- tf_py_test_dependencies = subprocess.check_output([
- 'bazel', 'query', PY_TEST_QUERY_EXPRESSION])
+ tf_py_test_dependencies = subprocess.check_output(
+ ["bazel", "query", PY_TEST_QUERY_EXPRESSION])
tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n")
print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list))
@@ -114,8 +118,7 @@ def main():
# Check if the dependency is in the pip package, the blacklist, or
# should be ignored because of its file extension
- if not (ignore or
- dependency in pip_package_dependencies_list or
+ if not (ignore or dependency in pip_package_dependencies_list or
dependency in BLACKLIST):
missing_dependencies.append(dependency)
@@ -126,19 +129,20 @@ def main():
for missing_dependency in missing_dependencies:
print("\nMissing dependency: %s " % missing_dependency)
print("Affected Tests:")
- rdep_query = 'rdeps(kind(py_test, \
- //tensorflow/python/...), %s)' % missing_dependency
- affected_tests = subprocess.check_output(['bazel', 'query', rdep_query])
+ rdep_query = ("rdeps(kind(py_test, //tensorflow/python/...), %s)" %
+ missing_dependency)
+ affected_tests = subprocess.check_output(["bazel", "query", rdep_query])
affected_tests_list = affected_tests.split("\n")[:-2]
print("\n".join(affected_tests_list))
raise RuntimeError("""One or more dependencies are not in the pip package.
Please either blacklist the dependencies in
-tensorflow/tensorflow/tensorflow/tools/pip_package/pip_smoke_test.py
-or add them to tensorflow/tensorflow/tensorflow/tools/pip_package/BUILD.""")
+//tensorflow/tools/pip_package/pip_smoke_test.py
+or add them to //tensorflow/tools/pip_package/BUILD.""")
else:
print("TEST PASSED")
+
if __name__ == "__main__":
main()
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h
index 44387bbd4d..e18d749cff 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
+#ifndef TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
+#define TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -50,4 +50,4 @@ ProtoTextFunctionCode GetProtoTextFunctionCode(
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
+#endif // TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD
index 28d651e910..159a8c1cfb 100644
--- a/tensorflow/tools/test/BUILD
+++ b/tensorflow/tools/test/BUILD
@@ -104,12 +104,3 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)
-
-py_test(
- name = "check_futures_test",
- size = "small",
- srcs = ["check_futures_test.py"],
- data = ["//tensorflow:all_opensource_files"],
- srcs_version = "PY2AND3",
- deps = ["@six_archive//:six"],
-)
diff --git a/tensorflow/tools/test/check_futures_test.py b/tensorflow/tools/test/check_futures_test.py
index 1c07511888..9181c9bd4a 100644
--- a/tensorflow/tools/test/check_futures_test.py
+++ b/tensorflow/tools/test/check_futures_test.py
@@ -33,7 +33,7 @@ import re
import six
-BASE_DIR = os.path.normpath(os.path.join(__file__, '../../..'))
+BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
FUTURES_PATTERN = re.compile(r'^from __future__ import (\w+)\s*$')
FUTURES_PATTERN_2 = re.compile(
r'^from __future__ import (\w+), (\w+), (\w+)\s*$')
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8850610cdb..f7d9075032 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,11 +1,12 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
-load("//third_party/tensorrt:build_defs.bzl", "trt_repository")
+load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
+load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure")
load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure")
load("//third_party:repo.bzl", "tf_http_archive")
load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external")
@@ -66,8 +67,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
# files, in case the parsing of those build files depends on the bazel
# version we require here.
check_bazel_version_at_least("0.5.4")
+ clang6_configure(name="local_config_clang6")
cuda_configure(name="local_config_cuda")
- trt_repository(name="local_config_tensorrt")
+ tensorrt_configure(name="local_config_tensorrt")
git_configure(name="local_config_git")
sycl_configure(name="local_config_sycl")
python_configure(name="local_config_python")
@@ -475,11 +477,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/f78b1b74e8a1c265f84ccd2142af88e346ce721e.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/f78b1b74e8a1c265f84ccd2142af88e346ce721e.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/11a2ca6eea8a7fe240a14c0c35fd2017341279be.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/11a2ca6eea8a7fe240a14c0c35fd2017341279be.tar.gz",
],
- sha256 = "d8e5966cf8e7489fa5a1b167f83bebaabe58aa7584b6d8a5c8f3722ae3047d19",
- strip_prefix = "llvm-f78b1b74e8a1c265f84ccd2142af88e346ce721e",
+ sha256 = "b5429ccf8d57273cb8489714f728c997cd720ec66fc2c0292422ab8f0e729ce0",
+ strip_prefix = "llvm-11a2ca6eea8a7fe240a14c0c35fd2017341279be",
build_file = str(Label("//third_party/llvm:llvm.BUILD")),
)
@@ -677,11 +679,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "bazel_toolchains",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/b49ba3689f46ac50e9277dafd8ff32b26951f82e.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/b49ba3689f46ac50e9277dafd8ff32b26951f82e.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/f3b09700fae5d7b6e659d7cefe0dcc6e8498504c.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/f3b09700fae5d7b6e659d7cefe0dcc6e8498504c.tar.gz",
],
- sha256 = "1266f1e27b4363c83222f1a776397c7a069fbfd6aacc9559afa61cdd73e1b429",
- strip_prefix = "bazel-toolchains-b49ba3689f46ac50e9277dafd8ff32b26951f82e",
+ sha256 = "ed829b5eea8af1f405f4cc3d6ecfc3b1365bb7843171036030a31b5127002311",
+ strip_prefix = "bazel-toolchains-f3b09700fae5d7b6e659d7cefe0dcc6e8498504c",
)
tf_http_archive(
diff --git a/third_party/aws.BUILD b/third_party/aws.BUILD
index bf5310aa16..2dc921933c 100644
--- a/third_party/aws.BUILD
+++ b/third_party/aws.BUILD
@@ -75,7 +75,7 @@ cc_library(
"aws-cpp-sdk-s3/include/",
],
deps = [
- "@curl//:curl",
+ "@curl",
],
)
diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD
index f5f3418527..f661093bc9 100644
--- a/third_party/eigen3/BUILD
+++ b/third_party/eigen3/BUILD
@@ -36,7 +36,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"@eigen_archive//:eigen",
- "@local_config_sycl//sycl:sycl",
+ "@local_config_sycl//sycl",
],
)
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
index c210b1712c..cb1636256d 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
@@ -1,5 +1,5 @@
-#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
-#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
+#ifndef EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
+#define EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
#ifdef _MSC_VER
@@ -502,4 +502,4 @@ struct functor_traits<scalar_product_op<QInt32, double>> {
} // end namespace internal
} // end namespace Eigen
-#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
+#endif // EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
index 7a222fddc1..8f9906dbf9 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
@@ -1,5 +1,5 @@
-#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_
-#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_
+#ifndef EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_
+#define EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_
#include "PacketMathAVX2.h"
@@ -542,4 +542,4 @@ EIGEN_STRONG_INLINE QInt8 predux_max<Packet64q8i>(const Packet64q8i& a) {
} // end namespace internal
} // end namespace Eigen
-#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_
+#endif // EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_PACKETMATHAVX512_H_
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
index 045384d7fc..7b4ecc752f 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
@@ -1,5 +1,5 @@
-#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
-#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
+#ifndef EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
+#define EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
namespace Eigen {
namespace internal {
@@ -63,4 +63,4 @@ pcast<Packet8q32i, Packet32q8u>(const Packet8q32i& a, const Packet8q32i& b,
} // end namespace internal
} // end namespace Eigen
-#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
+#endif // EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX2_H_
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
index cd7120ec00..26735743d4 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
@@ -1,5 +1,5 @@
-#ifndef THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_
-#define THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_
+#ifndef EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_
+#define EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_
namespace Eigen {
namespace internal {
@@ -177,4 +177,4 @@ pcast<Packet16q32i, Packet32q16u>(const Packet16q32i& a,
} // end namespace internal
} // end namespace Eigen
-#endif // THIRD_PARTY_EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_
+#endif // EIGEN3_UNSUPPORTED_EIGEN_CXX11_SRC_FIXEDPOINT_TYPECASTINGAVX512_H_
diff --git a/third_party/fft2d/fft.h b/third_party/fft2d/fft.h
index 252cc01fec..31b4935089 100644
--- a/third_party/fft2d/fft.h
+++ b/third_party/fft2d/fft.h
@@ -15,8 +15,8 @@ limitations under the License.
// Declarations for 1D FFT routines in third_party/fft2d/fft.
-#ifndef THIRD_PARTY_FFT2D_FFT_H__
-#define THIRD_PARTY_FFT2D_FFT_H__
+#ifndef FFT2D_FFT_H__
+#define FFT2D_FFT_H__
#ifdef __cplusplus
extern "C" {
@@ -33,4 +33,4 @@ extern void dfst(int, double *, double *, int *, double *);
}
#endif
-#endif // THIRD_PARTY_FFT2D_FFT_H__
+#endif // FFT2D_FFT_H__
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 2727fa5efe..8e1dd8a54f 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -236,7 +236,7 @@ def _cudnn_install_basedir(repository_ctx):
return cudnn_install_path
-def _matches_version(environ_version, detected_version):
+def matches_version(environ_version, detected_version):
"""Checks whether the user-specified version matches the detected version.
This function performs a weak matching so that if the user specifies only the
@@ -317,7 +317,7 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
environ_version = ""
if _TF_CUDA_VERSION in repository_ctx.os.environ:
environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
- if environ_version and not _matches_version(environ_version, full_version):
+ if environ_version and not matches_version(environ_version, full_version):
auto_configure_fail(
("CUDA version detected from nvcc (%s) does not match " +
"TF_CUDA_VERSION (%s)") % (full_version, environ_version))
@@ -338,35 +338,49 @@ _DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
-def _find_cuda_define(repository_ctx, cudnn_header_dir, define):
- """Returns the value of a #define in cudnn.h
+def find_cuda_define(repository_ctx, header_dir, header_file, define):
+ """Returns the value of a #define in a header file.
- Greps through cudnn.h and returns the value of the specified #define. If the
- #define is not found, then raise an error.
+ Greps through a header file and returns the value of the specified #define.
+ If the #define is not found, then raise an error.
Args:
repository_ctx: The repository context.
- cudnn_header_dir: The directory containing the cuDNN header.
+ header_dir: The directory containing the header file.
+ header_file: The header file name.
define: The #define to search for.
Returns:
- The value of the #define found in cudnn.h.
+ The value of the #define found in the header.
"""
- # Confirm location of cudnn.h and grep for the line defining CUDNN_MAJOR.
- cudnn_h_path = repository_ctx.path("%s/cudnn.h" % cudnn_header_dir)
- if not cudnn_h_path.exists:
- auto_configure_fail("Cannot find cudnn.h at %s" % str(cudnn_h_path))
- result = repository_ctx.execute(["grep", "--color=never", "-E", define, str(cudnn_h_path)])
+ # Confirm location of the header and grep for the line defining the macro.
+ h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
+ if not h_path.exists:
+ auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
+ result = repository_ctx.execute(
+ # Grep one more lines as some #defines are splitted into two lines.
+ ["grep", "--color=never", "-A1", "-E", define, str(h_path)])
if result.stderr:
- auto_configure_fail("Error reading %s: %s" %
- (result.stderr, str(cudnn_h_path)))
+ auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
- # Parse the cuDNN major version from the line defining CUDNN_MAJOR
- lines = result.stdout.splitlines()
- if len(lines) == 0 or lines[0].find(define) == -1:
+ # Parse the version from the line defining the macro.
+ if result.stdout.find(define) == -1:
auto_configure_fail("Cannot find line containing '%s' in %s" %
- (define, str(cudnn_h_path)))
- return lines[0].replace(define, "").strip()
+ (define, h_path))
+ version = result.stdout
+ # Remove the new line and '\' character if any.
+ version = version.replace("\\", " ")
+ version = version.replace("\n", " ")
+ version = version.replace(define, "").lstrip()
+ # Remove the code after the version number.
+ version_end = version.find(" ")
+ if version_end != -1:
+ if version_end == 0:
+ auto_configure_fail(
+ "Cannot extract the version from line containing '%s' in %s" %
+ (define, str(h_path)))
+ version = version[:version_end].strip()
+ return version
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
@@ -382,12 +396,12 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
"""
cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
cudnn_install_basedir)
- major_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_MAJOR)
- minor_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_MINOR)
- patch_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_PATCHLEVEL)
+ major_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MAJOR)
+ minor_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MINOR)
+ patch_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_PATCHLEVEL)
full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
# Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
@@ -395,7 +409,7 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
environ_version = ""
if _TF_CUDNN_VERSION in repository_ctx.os.environ:
environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
- if environ_version and not _matches_version(environ_version, full_version):
+ if environ_version and not matches_version(environ_version, full_version):
cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
cudnn_install_basedir)
auto_configure_fail(
@@ -427,7 +441,7 @@ def _compute_capabilities(repository_ctx):
return capabilities
-def _cpu_value(repository_ctx):
+def get_cpu_value(repository_ctx):
"""Returns the name of the host operating system.
Args:
@@ -447,7 +461,7 @@ def _cpu_value(repository_ctx):
def _is_windows(repository_ctx):
"""Returns true if the host operating system is windows."""
- return _cpu_value(repository_ctx) == "Windows"
+ return get_cpu_value(repository_ctx) == "Windows"
def _lib_name(lib, cpu_value, version="", static=False):
"""Constructs the platform-specific name of a library.
@@ -582,11 +596,8 @@ def _find_libs(repository_ctx, cuda_config):
cuda_config: The CUDA config as returned by _get_cuda_config
Returns:
- Map of library names to structs of filename and path as returned by
- _find_cuda_lib and _find_cupti_lib.
+ Map of library names to structs of filename and path.
"""
- cudnn_version = cuda_config.cudnn_version
- cudnn_ext = ".%s" % cudnn_version if cudnn_version else ""
cpu_value = cuda_config.cpu_value
return {
"cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
@@ -611,7 +622,7 @@ def _find_libs(repository_ctx, cuda_config):
"cudnn": _find_cuda_lib(
"cudnn", repository_ctx, cpu_value, cuda_config.cudnn_install_basedir,
cuda_config.cudnn_version),
- "cupti": _find_cupti_lib(repository_ctx, cuda_config),
+ "cupti": _find_cupti_lib(repository_ctx, cuda_config)
}
@@ -654,7 +665,7 @@ def _get_cuda_config(repository_ctx):
compute_capabilities: A list of the system's CUDA compute capabilities.
cpu_value: The name of the host operating system.
"""
- cpu_value = _cpu_value(repository_ctx)
+ cpu_value = get_cpu_value(repository_ctx)
cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
@@ -712,13 +723,13 @@ error_gpu_disabled()
def _create_dummy_repository(repository_ctx):
- cpu_value = _cpu_value(repository_ctx)
+ cpu_value = get_cpu_value(repository_ctx)
# Set up BUILD file for cuda/.
_tpl(repository_ctx, "cuda:build_defs.bzl",
{
"%{cuda_is_configured}": "False",
- "%{cuda_extra_copts}": "[]"
+ "%{cuda_extra_copts}": "[]",
})
_tpl(repository_ctx, "cuda:BUILD",
{
@@ -805,8 +816,8 @@ def _norm_path(path):
return path
-def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
- src_files = [], dest_files = []):
+def symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
+ src_files = [], dest_files = []):
"""Returns a genrule to symlink(or copy if on Windows) a set of files.
If src_dir is passed, files will be read from the given directory; otherwise
@@ -913,11 +924,11 @@ def _create_local_cuda_repository(repository_ctx):
# cuda_toolkit_path
cuda_toolkit_path = cuda_config.cuda_toolkit_path
cuda_include_path = cuda_toolkit_path + "/include"
- genrules = [_symlink_genrule_for_dir(repository_ctx,
+ genrules = [symlink_genrule_for_dir(repository_ctx,
cuda_include_path, "cuda/include", "cuda-include")]
- genrules.append(_symlink_genrule_for_dir(repository_ctx,
+ genrules.append(symlink_genrule_for_dir(repository_ctx,
cuda_toolkit_path + "/nvvm", "cuda/nvvm", "cuda-nvvm"))
- genrules.append(_symlink_genrule_for_dir(repository_ctx,
+ genrules.append(symlink_genrule_for_dir(repository_ctx,
cuda_toolkit_path + "/extras/CUPTI/include",
"cuda/extras/CUPTI/include", "cuda-extras"))
@@ -927,15 +938,15 @@ def _create_local_cuda_repository(repository_ctx):
for lib in cuda_libs.values():
cuda_lib_src.append(lib.path)
cuda_lib_dest.append("cuda/lib/" + lib.file_name)
- genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
- cuda_lib_src, cuda_lib_dest))
+ genrules.append(symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
+ cuda_lib_src, cuda_lib_dest))
- # Set up the symbolic links for cudnn if cudnn was was not installed to
+ # Set up the symbolic links for cudnn if cndnn was not installed to
# CUDA_TOOLKIT_PATH.
included_files = _read_dir(repository_ctx, cuda_include_path).replace(
cuda_include_path, '').splitlines()
if '/cudnn.h' not in included_files:
- genrules.append(_symlink_genrule_for_dir(repository_ctx, None,
+ genrules.append(symlink_genrule_for_dir(repository_ctx, None,
"cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"],
["cudnn.h"]))
else:
@@ -952,7 +963,6 @@ def _create_local_cuda_repository(repository_ctx):
"%{cuda_is_configured}": "True",
"%{cuda_extra_copts}": _compute_cuda_extra_copts(
repository_ctx, cuda_config.compute_capabilities),
-
})
_tpl(repository_ctx, "cuda:BUILD",
{
diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD
index 527a08c4b3..ca2d38d687 100644
--- a/third_party/jpeg/jpeg.BUILD
+++ b/third_party/jpeg/jpeg.BUILD
@@ -34,6 +34,10 @@ libjpegturbo_copts = select({
"-mfloat-abi=softfp",
"-fprefetch-loop-arrays",
],
+ ":linux_ppc64le": [
+ "-mcpu=power8",
+ "-mtune=power8",
+ ],
"//conditions:default": [],
})
@@ -123,11 +127,51 @@ cc_library(
":k8": [":simd_x86_64"],
":armeabi-v7a": [":simd_armv7a"],
":arm64-v8a": [":simd_armv8a"],
+ ":linux_ppc64le": [":simd_altivec"],
"//conditions:default": [":simd_none"],
}),
)
cc_library(
+ name = "simd_altivec",
+ srcs = [
+ "jchuff.h",
+ "jconfig.h",
+ "jdct.h",
+ "jerror.h",
+ "jinclude.h",
+ "jmorecfg.h",
+ "jpegint.h",
+ "jpeglib.h",
+ "jsimd.h",
+ "jsimddct.h",
+ "simd/jsimd.h",
+ "simd/jccolor-altivec.c",
+ "simd/jcgray-altivec.c",
+ "simd/jcsample-altivec.c",
+ "simd/jdcolor-altivec.c",
+ "simd/jdmerge-altivec.c",
+ "simd/jdsample-altivec.c",
+ "simd/jfdctfst-altivec.c",
+ "simd/jfdctint-altivec.c",
+ "simd/jidctfst-altivec.c",
+ "simd/jidctint-altivec.c",
+ "simd/jquanti-altivec.c",
+ "simd/jsimd_powerpc.c",
+ "simd/jsimd_altivec.h",
+ "simd/jcsample.h",
+ ],
+ hdrs = [
+ "simd/jdmrgext-altivec.c", # should have been named .inc
+ "simd/jccolext-altivec.c", # should have been named .inc
+ "simd/jcgryext-altivec.c", # should have been named .inc
+ "simd/jdcolext-altivec.c", # should have been named .inc
+ ],
+ copts = libjpegturbo_copts,
+ nocopts = libjpegturbo_nocopts,
+)
+
+cc_library(
name = "simd_x86_64",
srcs = [
"jchuff.h",
@@ -219,7 +263,7 @@ genrule(
" -o $$out" +
" $$(dirname $(location simd/jdct.inc))/$$(basename $${out%.o}.asm)\n" +
"done",
- tools = ["@nasm//:nasm"],
+ tools = ["@nasm"],
)
cc_library(
@@ -381,6 +425,7 @@ genrule(
":k8": "cp $(location jconfig_nowin_simd.h) $@",
":armeabi-v7a": "cp $(location jconfig_nowin_simd.h) $@",
":arm64-v8a": "cp $(location jconfig_nowin_simd.h) $@",
+ ":linux_ppc64le": "cp $(location jconfig_nowin_simd.h) $@",
"//conditions:default": "cp $(location jconfig_nowin_nosimd.h) $@",
}),
)
@@ -498,3 +543,9 @@ config_setting(
name = "windows_msvc",
values = {"cpu": "x64_windows_msvc"},
)
+
+config_setting(
+ name = "linux_ppc64le",
+ values = {"cpu": "ppc"},
+
+)
diff --git a/third_party/swig.BUILD b/third_party/swig.BUILD
index d698fa934b..f2f647401b 100644
--- a/third_party/swig.BUILD
+++ b/third_party/swig.BUILD
@@ -89,7 +89,7 @@ cc_binary(
],
output_licenses = ["unencumbered"],
visibility = ["//visibility:public"],
- deps = ["@pcre//:pcre"],
+ deps = ["@pcre"],
)
filegroup(
diff --git a/third_party/tensorrt/BUILD.tpl b/third_party/tensorrt/BUILD.tpl
index a8e52d13d3..6cb7db7e90 100644
--- a/third_party/tensorrt/BUILD.tpl
+++ b/third_party/tensorrt/BUILD.tpl
@@ -1,38 +1,69 @@
-# -*- python -*-
+# NVIDIA TensorRT
+# A high-performance deep learning inference optimizer and runtime.
-licenses(["notice"])
+licenses(["notice"])
exports_files(["LICENSE"])
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
-config_setting(
- name = "trt_enabled",
- define_values = {
- "using_tensorrt":"true"
- },
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "tensorrt_headers",
+ hdrs = [%{tensorrt_headers}],
+ includes = [
+ "include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nv_infer",
+ srcs = [%{nv_infer}],
+ data = [%{nv_infer}],
+ includes = [
+ "include",
+ ],
+ copts= cuda_default_copts(),
+ deps = [
+ "@local_config_cuda//cuda:cuda",
+ ":tensorrt_headers",
+ ],
+ linkstatic = 1,
visibility = ["//visibility:public"],
)
cc_library(
- name = "tensorrt",
- srcs =[%{tensorrt_lib}],
- hdrs = ["include/NvInfer.h",
- "include/NvUtils.h",
+ name = "nv_infer_plugin",
+ srcs = [%{nv_infer_plugin}],
+ data = [%{nv_infer_plugin}],
+ includes = [
+ "include",
],
copts= cuda_default_copts(),
- deps =["@local_config_cuda//cuda:cuda",
- "@local_config_cuda//cuda:cudnn",],
+ deps = [
+ "@local_config_cuda//cuda:cuda",
+ ":nv_infer",
+ ":tensorrt_headers",
+ ],
linkstatic = 1,
- #include_prefix="include/",
- includes=["include/"],
- visibility = ["//visibility:public"],
+ visibility = ["//visibility:public"],
)
-%{tensorrt_genrules}
+cc_library(
+ name = "nv_parsers",
+ srcs = [%{nv_parsers}],
+ data = [%{nv_parsers}],
+ includes = [
+ "include",
+ ],
+ copts= cuda_default_copts(),
+ deps = [
+ ":tensorrt_headers",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
-# filegroup(
-# name = "%{tensorrt_lib}",
-# srcs = ["%{tensorrt_lib}"],
-# visibility = ["//visibility:public"],
-# )
+%{tensorrt_genrules} \ No newline at end of file
diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl
index 18f354ee5a..8a89b59bc8 100644
--- a/third_party/tensorrt/build_defs.bzl.tpl
+++ b/third_party/tensorrt/build_defs.bzl.tpl
@@ -1,18 +1,7 @@
-# -*- python -*-
-"""
-template file for trt functions
+# Build configurations for TensorRT.
-"""
-
-def is_trt_enabled():
- return %{trt_configured}
-
-def if_trt(if_true,if_false=[]):
- # if is_trt_enabled():
- # return if_true
- # return if_false
-
- return select({
- "@local_config_tensorrt//:trt_enabled":if_true,
- "//conditions:default":if_false,
- })
+def if_tensorrt(if_true, if_false=[]):
+ """Tests whether TensorRT was enabled during the configure process."""
+ if %{tensorrt_is_configured}:
+ return if_true
+ return if_false \ No newline at end of file
diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl
new file mode 100644
index 0000000000..8aa0f28f39
--- /dev/null
+++ b/third_party/tensorrt/tensorrt_configure.bzl
@@ -0,0 +1,224 @@
+# -*- Python -*-
+"""Repository rule for TensorRT configuration.
+
+`tensorrt_configure` depends on the following environment variables:
+
+ * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
+ * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
+"""
+
+load(
+ "//third_party/gpus:cuda_configure.bzl",
+ "auto_configure_fail",
+ "get_cpu_value",
+ "find_cuda_define",
+ "matches_version",
+ "symlink_genrule_for_dir",
+)
+
+_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
+_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
+
+_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin", "nvparsers"]
+_TF_TENSORRT_HEADERS = [
+ "NvInfer.h", "NvInferPlugin.h", "NvCaffeParser.h", "NvUffParser.h",
+ "NvUtils.h"
+]
+
+_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
+_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
+_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
+
+
+def _headers_exist(repository_ctx, path):
+ """Returns whether all TensorRT header files could be found in 'path'.
+
+ Args:
+ repository_ctx: The repository context.
+ path: The TensorRT include path to check.
+
+ Returns:
+ True if all TensorRT header files can be found in the path.
+ """
+ for h in _TF_TENSORRT_HEADERS:
+ if not repository_ctx.path("%s/%s" % (path, h)).exists:
+ return False
+ return True
+
+
+def _find_trt_header_dir(repository_ctx, trt_install_path):
+ """Returns the path to the directory containing headers of TensorRT.
+
+ Args:
+ repository_ctx: The repository context.
+ trt_install_path: The TensorRT library install directory.
+
+ Returns:
+ The path of the directory containing the TensorRT header.
+ """
+ if trt_install_path == "/usr/lib/x86_64-linux-gnu":
+ path = "/usr/include/x86_64-linux-gnu"
+ if _headers_exist(repository_ctx, path):
+ return path
+ path = str(repository_ctx.path("%s/../include" % trt_install_path).realpath)
+ if _headers_exist(repository_ctx, path):
+ return path
+ auto_configure_fail(
+ "Cannot find NvInfer.h with TensorRT install path %s" % trt_install_path)
+
+
+def _trt_lib_version(repository_ctx, trt_install_path):
+ """Detects the library (e.g. libnvinfer) version of TensorRT.
+
+ Args:
+ repository_ctx: The repository context.
+ trt_install_path: The TensorRT library install directory.
+
+ Returns:
+ A string containing the library version of TensorRT.
+ """
+ trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
+ major_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
+ _DEFINE_TENSORRT_SONAME_MAJOR)
+ minor_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
+ _DEFINE_TENSORRT_SONAME_MINOR)
+ patch_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
+ _DEFINE_TENSORRT_SONAME_PATCH)
+ full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
+ environ_version = repository_ctx.os.environ[_TF_TENSORRT_VERSION].strip()
+ if not matches_version(environ_version, full_version):
+ auto_configure_fail(
+ ("TensorRT library version detected from %s/%s (%s) does not match " +
+ "TF_TENSORRT_VERSION (%s). To fix this rerun configure again.") %
+ (trt_header_dir, "NvInfer.h", full_version, environ_version))
+ return environ_version
+
+
+def _find_trt_libs(repository_ctx, trt_install_path, trt_lib_version):
+ """Finds the given TensorRT library on the system.
+
+ Adapted from code contributed by Sami Kama (https://github.com/samikama).
+
+ Args:
+ repository_ctx: The repository context.
+ trt_install_path: The TensorRT library installation directory.
+ trt_lib_version: The version of TensorRT library files as returned
+ by _trt_lib_version.
+
+ Returns:
+ Map of library names to structs with the following fields:
+ src_file_path: The full path to the library found on the system.
+ dst_file_name: The basename of the target library.
+ """
+ objdump = repository_ctx.which("objdump")
+ result = {}
+ for lib in _TF_TENSORRT_LIBS:
+ dst_file_name = "lib%s.so.%s" % (lib, trt_lib_version)
+ src_file_path = repository_ctx.path("%s/%s" % (trt_install_path,
+ dst_file_name))
+ if not src_file_path.exists:
+ auto_configure_fail(
+ "Cannot find TensorRT library %s" % str(src_file_path))
+ if objdump != None:
+ objdump_out = repository_ctx.execute([objdump, "-p", str(src_file_path)])
+ for line in objdump_out.stdout.splitlines():
+ if "SONAME" in line:
+ dst_file_name = line.strip().split(" ")[-1]
+ result.update({
+ lib:
+ struct(
+ dst_file_name=dst_file_name,
+ src_file_path=str(src_file_path.realpath))
+ })
+ return result
+
+
+def _tpl(repository_ctx, tpl, substitutions):
+ repository_ctx.template(tpl, Label("//third_party/tensorrt:%s.tpl" % tpl),
+ substitutions)
+
+
+def _create_dummy_repository(repository_ctx):
+ """Create a dummy TensorRT repository."""
+ _tpl(repository_ctx, "build_defs.bzl", {"%{tensorrt_is_configured}": "False"})
+ substitutions = {
+ "%{tensorrt_genrules}": "",
+ "%{tensorrt_headers}": "",
+ }
+ for lib in _TF_TENSORRT_LIBS:
+ k = "%%{%s}" % lib.replace("nv", "nv_")
+ substitutions.update({k: ""})
+ _tpl(repository_ctx, "BUILD", substitutions)
+
+
+def _tensorrt_configure_impl(repository_ctx):
+ """Implementation of the tensorrt_configure repository rule."""
+ if _TENSORRT_INSTALL_PATH not in repository_ctx.os.environ:
+ _create_dummy_repository(repository_ctx)
+ return
+
+ if (get_cpu_value(repository_ctx) != "Linux"):
+ auto_configure_fail("TensorRT is supported only on Linux.")
+ if _TF_TENSORRT_VERSION not in repository_ctx.os.environ:
+ auto_configure_fail("TensorRT library (libnvinfer) version is not set.")
+ trt_install_path = repository_ctx.os.environ[_TENSORRT_INSTALL_PATH].strip()
+ if not repository_ctx.path(trt_install_path).exists:
+ auto_configure_fail(
+ "Cannot find TensorRT install path %s." % trt_install_path)
+
+ # Set up the symbolic links for the library files.
+ trt_lib_version = _trt_lib_version(repository_ctx, trt_install_path)
+ trt_libs = _find_trt_libs(repository_ctx, trt_install_path, trt_lib_version)
+ trt_lib_src = []
+ trt_lib_dest = []
+ for lib in trt_libs.values():
+ trt_lib_src.append(lib.src_file_path)
+ trt_lib_dest.append(lib.dst_file_name)
+ genrules = [
+ symlink_genrule_for_dir(repository_ctx, None, "tensorrt/lib/",
+ "tensorrt_lib", trt_lib_src, trt_lib_dest)
+ ]
+
+ # Set up the symbolic links for the header files.
+ trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
+ src_files = [
+ "%s/%s" % (trt_header_dir, header) for header in _TF_TENSORRT_HEADERS
+ ]
+ dest_files = _TF_TENSORRT_HEADERS
+ genrules.append(
+ symlink_genrule_for_dir(repository_ctx, None, "tensorrt/include/",
+ "tensorrt_include", src_files, dest_files))
+
+ # Set up config file.
+ _tpl(repository_ctx, "build_defs.bzl", {"%{tensorrt_is_configured}": "True"})
+
+ # Set up BUILD file.
+ substitutions = {
+ "%{tensorrt_genrules}": "\n".join(genrules),
+ "%{tensorrt_headers}": '":tensorrt_include"',
+ }
+ for lib in _TF_TENSORRT_LIBS:
+ k = "%%{%s}" % lib.replace("nv", "nv_")
+ v = '"tensorrt/lib/%s"' % trt_libs[lib].dst_file_name
+ substitutions.update({k: v})
+ _tpl(repository_ctx, "BUILD", substitutions)
+
+
+tensorrt_configure = repository_rule(
+ implementation=_tensorrt_configure_impl,
+ environ=[
+ _TENSORRT_INSTALL_PATH,
+ _TF_TENSORRT_VERSION,
+ ],
+)
+"""Detects and configures the local CUDA toolchain.
+
+Add the following to your WORKSPACE FILE:
+
+```python
+tensorrt_configure(name = "local_config_tensorrt")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""
diff --git a/third_party/toolchains/clang6/BUILD b/third_party/toolchains/clang6/BUILD
new file mode 100644
index 0000000000..ffd0fb0cdc
--- /dev/null
+++ b/third_party/toolchains/clang6/BUILD
@@ -0,0 +1 @@
+package(default_visibility = ["//visibility:public"])
diff --git a/third_party/toolchains/clang6/CROSSTOOL.tpl b/third_party/toolchains/clang6/CROSSTOOL.tpl
new file mode 100644
index 0000000000..6b7e5a8808
--- /dev/null
+++ b/third_party/toolchains/clang6/CROSSTOOL.tpl
@@ -0,0 +1,587 @@
+major_version: "v1"
+minor_version: "llvm:6.0.0"
+default_target_cpu: "k8"
+
+default_toolchain {
+ cpu: "k8"
+ toolchain_identifier: "k8-clang-6.0-cxx-4.8-linux-gnu"
+}
+
+toolchain {
+ compiler: "clang6" # bazel build --compiler=clang6
+ target_cpu: "k8" # bazel build --cpu=k8
+ target_libc: "GLIBC_2.19" # bazel build --glibc=GLIBC_2.19
+
+ abi_libc_version: "2.19"
+ abi_version: "gcc-4.8-cxx11"
+ builtin_sysroot: ""
+ cc_target_os: "linux-gnu"
+ default_python_version: "python2.7"
+ dynamic_runtimes_filegroup: "dynamic-runtime-libs-k8"
+ host_system_name: "x86_64-unknown-linux-gnu"
+ needsPic: true
+ static_runtimes_filegroup: "static-runtime-libs-k8"
+ supports_embedded_runtimes: true
+ supports_fission: true
+ supports_gold_linker: true
+ supports_incremental_linker: true
+ supports_interface_shared_objects: true
+ supports_normalizing_ar: true
+ supports_start_end_lib: true
+ supports_thin_archives: true
+ target_system_name: "x86_64-unknown-linux-gnu"
+ toolchain_identifier: "k8-clang-6.0-cxx-4.8-linux-gnu"
+
+ tool_path { name: "ar" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-ar" }
+ tool_path { name: "as" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-as" }
+ tool_path { name: "compat-ld" path: "%package(@local_config_clang6//clang6)%/llvm/bin/ld.lld" }
+ tool_path { name: "cpp" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-cpp" }
+ tool_path { name: "dwp" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-dwp" }
+ tool_path { name: "gcc" path: "%package(@local_config_clang6//clang6)%/llvm/bin/clang" }
+ tool_path { name: "gcov" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-cov" }
+ tool_path { name: "ld" path: "%package(@local_config_clang6//clang6)%/llvm/bin/ld.lld" }
+ tool_path { name: "llvm-profdata" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-profdata" }
+ tool_path { name: "nm" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-nm" }
+ tool_path { name: "objcopy" path: "%package(@local_config_clang6//clang6)%/llvm/bin/llvm-objcopy" }
+ tool_path { name: "objdump" path: "%package(@local_config_clang6//clang6)%/sbin/objdump" }
+ tool_path { name: "strip" path: "%package(@local_config_clang6//clang6)%/sbin/strip" }
+
+ unfiltered_cxx_flag: "-no-canonical-prefixes"
+
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ unfiltered_cxx_flag: "-Wno-builtin-macro-redefined"
+ unfiltered_cxx_flag: "-D__DATE__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIME__=\"redacted\""
+
+ objcopy_embed_flag: "-I"
+ objcopy_embed_flag: "binary"
+
+ # This action_config makes features flags propagate
+ # to CC_FLAGS for genrules, and eventually skylark.
+ action_config {
+ action_name: "cc-flags-make-variable"
+ config_name: "cc-flags-make-variable"
+ }
+
+ # Security hardening on by default.
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now have
+ # it enabled by default.
+ compiler_flag: "-U_FORTIFY_SOURCE"
+ compiler_flag: "-D_FORTIFY_SOURCE=1"
+ compiler_flag: "-fstack-protector"
+ linker_flag: "-Wl,-z,relro,-z,now"
+
+ # This adds a little bit more durability to our Clang build.
+ #
+ # At the moment, this only only be needed for:
+ # - add_boringssl_s390x.patch: --Wa,--noexecstack
+ #
+ # Folks who do maintenance work on TF Bazel Clang should consider
+ # commenting out these lines, while doing that work, to gain a better
+ # understanding of what the intersection of support looks like between GCC
+ # and Clang. Please note that, Bazel does not support
+ # -Xclang-only / -Xgcc-only.
+ compiler_flag: "-Wno-unknown-warning-option"
+ compiler_flag: "-Wno-unused-command-line-argument"
+ compiler_flag: "-Wno-ignored-optimization-argument"
+
+ #### Common compiler options. ####
+ compiler_flag: "-D_REENTRANT"
+ compiler_flag: "-D__STDC_FORMAT_MACROS"
+ compiler_flag: "-DSUPPRESS_USE_FILE_OFFSET64"
+ compiler_flag: "-Wall"
+ compiler_flag: "-Wformat-security"
+ compiler_flag: "-Wframe-larger-than=16384"
+ compiler_flag: "-Wno-char-subscripts"
+ compiler_flag: "-Wno-error=deprecated-declarations"
+ compiler_flag: "-Wno-uninitialized"
+ compiler_flag: "-Wno-sign-compare"
+ compiler_flag: "-Wno-strict-overflow"
+ compiler_flag: "-Wno-unused-function"
+ compiler_flag: "-fdiagnostics-show-option"
+ compiler_flag: "-fmessage-length=0"
+ compiler_flag: "-fno-exceptions"
+ compiler_flag: "-fno-omit-frame-pointer"
+ compiler_flag: "-fno-strict-aliasing"
+ compiler_flag: "-fno-use-init-array"
+ compiler_flag: "-funsigned-char"
+ compiler_flag: "-gmlt"
+ cxx_flag: "-Wno-deprecated"
+ cxx_flag: "-Wno-invalid-offsetof" # Needed for protobuf code (2017-11-07)
+ cxx_flag: "-fshow-overloads=best"
+ compiler_flag: "-Wthread-safety-analysis"
+
+ # Python extensions unfortunately make this go wild.
+ compiler_flag: "-Wno-writable-strings"
+
+ # GCC's warning produces too many false positives:
+ cxx_flag: "-Woverloaded-virtual"
+ cxx_flag: "-Wnon-virtual-dtor"
+
+ # Enable coloring even if there's no attached terminal. Bazel removes the
+ # escape sequences if --nocolor is specified. This isn't supported by gcc
+ # on Ubuntu 14.04.
+ compiler_flag: "-fcolor-diagnostics"
+
+ # Disable some broken warnings from Clang.
+ compiler_flag: "-Wno-ambiguous-member-template"
+ compiler_flag: "-Wno-pointer-sign"
+
+ # These warnings have a low signal to noise ratio.
+ compiler_flag: "-Wno-reserved-user-defined-literal"
+ compiler_flag: "-Wno-return-type-c-linkage"
+ compiler_flag: "-Wno-invalid-source-encoding"
+
+ # Per default we switch off any layering related warnings.
+ compiler_flag: "-Wno-private-header"
+
+ # Clang-specific warnings that we explicitly enable for TensorFlow. Some of
+ # these aren't on by default, or under -Wall, or are subsets of warnings
+ # turned off above.
+ compiler_flag: "-Wfloat-overflow-conversion"
+ compiler_flag: "-Wfloat-zero-conversion"
+ compiler_flag: "-Wfor-loop-analysis"
+ compiler_flag: "-Wgnu-redeclared-enum"
+ compiler_flag: "-Winfinite-recursion"
+ compiler_flag: "-Wliteral-conversion"
+ compiler_flag: "-Wself-assign"
+ compiler_flag: "-Wstring-conversion"
+ compiler_flag: "-Wtautological-overlap-compare"
+ compiler_flag: "-Wunused-comparison"
+ compiler_flag: "-Wvla"
+ cxx_flag: "-Wdeprecated-increment-bool"
+
+ # Clang code-generation flags for performance optimization.
+ compiler_flag: "-faligned-allocation"
+ compiler_flag: "-fnew-alignment=8"
+
+ # Clang defaults to C99 while GCC defaults to C89. GCC plugins are written in
+ # C89 and don't have a BUILD rule we could add a copts flag to.
+ gcc_plugin_compiler_flag: "-std=gnu89"
+
+ compilation_mode_flags {
+ mode: FASTBUILD
+ }
+
+ compilation_mode_flags {
+ mode: DBG
+ compiler_flag: "-g"
+ }
+
+ compilation_mode_flags {
+ mode: OPT
+ compiler_flag: "-g0"
+ compiler_flag: "-fdebug-types-section"
+ compiler_flag: "-DNDEBUG"
+ compiler_flag: "-fno-split-dwarf-inlining"
+ compiler_flag: "-Os"
+ compiler_flag: "-fexperimental-new-pass-manager"
+ compiler_flag: "-fdebug-info-for-profiling"
+ compiler_flag: "-ffunction-sections"
+ compiler_flag: "-fdata-sections"
+ linker_flag: "-Wl,--gc-sections"
+ linker_flag: "-Wl,-z,relro,-z,now"
+ }
+
+ # Features indicating whether this is a host compile or not. Exactly one of
+ # these will be implicitly provided by bazel.
+ feature { name: "host" }
+ feature { name: "nonhost" }
+
+ # Features indicating which compiler will be used for code generation.
+ feature {
+ name: "llvm_codegen"
+ provides: "codegen"
+ enabled: true
+ }
+
+ # Features for compilation modes. Exactly one of these will be implicitly
+ # provided by bazel.
+ feature { name: "fastbuild" }
+ feature { name: "dbg" }
+ feature { name: "opt" }
+
+ # Features controlling the C++ language mode.
+ feature {
+ name: "c++11"
+ provides: "c++std"
+ flag_set {
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "linkstamp-compile"
+ flag_group {
+ flag: "-nostdinc++"
+ flag: "-std=c++11"
+ flag: "-Wc++14-extensions"
+ flag: "-Wc++2a-extensions"
+ flag: "-Wno-binary-literal"
+ }
+ }
+ }
+ feature {
+ name: "c++14"
+ provides: "c++std"
+ flag_set {
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "linkstamp-compile"
+ flag_group {
+ flag: "-nostdinc++"
+ flag: "-std=c++14"
+ flag: "-Wc++11-compat"
+ flag: "-Wno-c++11-compat-binary-literal"
+ flag: "-Wc++2a-extensions"
+ }
+ }
+ }
+ feature {
+ name: "c++17"
+ provides: "c++std"
+ flag_set {
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "linkstamp-compile"
+ flag_group {
+ flag: "-nostdinc++"
+ flag: "-std=c++17"
+ flag: "-Wc++11-compat"
+ flag: "-Wno-c++11-compat-binary-literal"
+ flag: "-Wc++2a-extensions"
+ }
+ }
+ }
+ feature {
+ name: "c++2a"
+ provides: "c++std"
+ flag_set {
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "linkstamp-compile"
+ flag_group {
+ flag: "-nostdinc++"
+ flag: "-std=c++2a"
+ flag: "-Wc++11-compat"
+ flag: "-Wno-c++11-compat-binary-literal"
+ }
+ }
+ }
+ feature {
+ name: "c++default"
+ enabled: true
+ flag_set {
+ # Provide the c++11 flags if no standard is selected
+ with_feature {
+ not_feature: "c++11"
+ not_feature: "c++14"
+ not_feature: "c++17"
+ not_feature: "c++2a"
+ }
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "linkstamp-compile"
+ flag_group {
+ flag: "-nostdinc++"
+ flag: "-std=c++11"
+ flag: "-Wc++14-extensions"
+ flag: "-Wc++2a-extensions"
+ flag: "-Wno-binary-literal"
+ }
+ }
+ }
+
+ feature {
+ name: "use_compiler_rt"
+ requires { feature: "llvm_codegen" }
+ # TODO(saugustine): At the moment, "use_compiler_rt" also
+ # requires "linking_mode_flags { mode: FULLY_STATIC" ... },
+ # but that isn't a feature. We should probably convert it.
+ flag_set {
+ action: "c++-link"
+ action: "c++-link-interface-dynamic-library"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-executable"
+ # "link" is a misnomer for these actions. They are really just
+ # invocations of ar.
+ #action: "c++-link-pic-static-library"
+ #action: "c++-link-static-library"
+ #action: "c++-link-alwayslink-static-library"
+ #action: "c++-link-pic-static-library"
+ #action: "c++-link-alwayslink-pic-static-library"
+ flag_group {
+ flag: "-rtlib=compiler-rt"
+ flag: "-lunwind"
+ }
+ }
+ }
+
+ feature {
+ name: "pie"
+ flag_set {
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "cc-flags-make-variable"
+ action: "lto-backend"
+ action: "linkstamp-compile"
+ flag_group {
+ flag: "-mpie-copy-relocations"
+ flag: "-fPIE"
+ }
+ }
+ flag_set {
+ action: "cc-flags-make-variable"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-pie"
+ }
+ }
+ }
+
+ # Pic must appear after pie, because pic may need to override pie, and bazel
+ # turns it on selectively. These don't interact with other options.
+ #
+ # TODO: In practice, normal vs pic vs pie is a ternary mode. We should
+ # implement it that way. This will require changes to bazel, which only
+ # calculates whether or not pic is needed, not pie.
+ #
+ # NOTE: Bazel might make this all a moot point.
+ feature {
+ name: "pic"
+ flag_set {
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-codegen"
+ action: "c++-module-compile"
+ action: "linkstamp-compile"
+ expand_if_all_available: "pic"
+ flag_group {
+ flag: "-fPIC"
+ }
+ }
+ }
+
+ feature {
+ name: "gold"
+ enabled: true
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-interface-dynamic-library"
+ flag_group {
+ expand_if_none_available: "lto"
+ flag: "-fuse-ld=gold"
+ }
+ }
+ }
+
+ # This is great if you want linking TensorFlow to take ten minutes.
+ feature {
+ name: "lto"
+ requires { feature: "nonhost" }
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-flto=thin"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-interface-dynamic-library"
+ flag_group {
+ flag: "-flto=thin"
+ }
+ }
+ }
+
+ feature {
+ name: "parse_headers"
+ flag_set {
+ action: "c++-header-parsing"
+ flag_group {
+ flag: "-xc++-header"
+ flag: "-fsyntax-only"
+ }
+ }
+ }
+
+ feature {
+ name: "preprocess_headers"
+ flag_set {
+ action: "c++-header-preprocessing"
+ flag_group {
+ flag: "-xc++"
+ flag: "-E"
+ }
+ }
+ }
+
+ feature {
+ name: "per_object_debug_info"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-codegen"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "lto-backend"
+ flag_group {
+ flag: "-gsplit-dwarf"
+ flag: "-ggnu-pubnames"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-interface-dynamic-library"
+ flag_group {
+ expand_if_all_available: "is_using_fission"
+ flag: "-Wl,--gdb-index"
+ }
+ }
+ }
+
+ feature {
+ name: "xray"
+ requires {
+ feature: "llvm_codegen"
+ feature: "nonhost"
+ }
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "c++-link-interface-dynamic-library"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-fxray-instrument"
+ }
+ }
+ }
+
+ feature {
+ name: "minimal_ubsan"
+ requires { feature: "llvm_codegen" }
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ flag_group {
+ flag: "-fsanitize=return,returns-nonnull-attribute,vla-bound,unreachable,float-cast-overflow"
+ flag: "-fsanitize-trap=all"
+ flag: "-DUNDEFINED_BEHAVIOR_SANITIZER"
+ }
+ }
+ }
+
+ feature {
+ name: "minimal_ubsan_enabled_by_default"
+ requires {
+ feature: "llvm_codegen"
+ feature: "fastbuild"
+ }
+ enabled: true
+ implies: "minimal_ubsan"
+ }
+
+ cxx_builtin_include_directory: "%package(@local_config_clang6//clang6)%/llvm/lib/clang/6.0.0/include"
+ cxx_builtin_include_directory: "/usr/include"
+
+ unfiltered_cxx_flag: "-cxx-isystem"
+ unfiltered_cxx_flag: "/usr/include/c++/4.8"
+ unfiltered_cxx_flag: "-cxx-isystem"
+ unfiltered_cxx_flag: "/usr/include/x86_64-linux-gnu/c++/4.8"
+ unfiltered_cxx_flag: "-isystem"
+ unfiltered_cxx_flag: "%package(@local_config_clang6//clang6)%/llvm/lib/clang/6.0.0/include"
+ unfiltered_cxx_flag: "-isystem"
+ unfiltered_cxx_flag: "/usr/include/x86_64-linux-gnu"
+ unfiltered_cxx_flag: "-isystem"
+ unfiltered_cxx_flag: "/usr/include"
+
+ linker_flag: "-Wl,--build-id=md5"
+ linker_flag: "-Wl,--fatal-warnings"
+ linker_flag: "-Wl,--hash-style=gnu"
+ linker_flag: "-no-canonical-prefixes"
+ linker_flag: "--target=x86_64-unknown-linux-gnu"
+
+ linker_flag: "-L/usr/lib/gcc/x86_64-linux-gnu/4.8"
+
+ # This is the minimum x86 architecture TensorFlow supports.
+ compiler_flag: "-DARCH_K8"
+ compiler_flag: "-m64"
+
+ # These are for Linux.
+ ld_embed_flag: "-melf_x86_64"
+ linker_flag: "-Wl,--eh-frame-hdr"
+ linker_flag: "-Wl,-z,max-page-size=0x1000"
+
+ # Google never uses the stack like a heap, e.g. alloca(), because tcmalloc
+ # and jemalloc are so fast. However copts=["$(STACK_FRAME_UNLIMITED)"] can be
+ # specified when that can't be the case.
+ make_variable {
+ name: "STACK_FRAME_UNLIMITED"
+ value: "-Wframe-larger-than=100000000 -Wno-vla"
+ }
+
+ # These flags are for folks who build C/C++ code inside genrules.
+ make_variable {
+ name: "CC_FLAGS"
+ value: "-no-canonical-prefixes --target=x86_64-unknown-linux-gnu -fno-omit-frame-pointer -fno-tree-vrp -msse3"
+ }
+
+ feature {
+ name: "copts"
+ flag_set {
+ expand_if_all_available: "copts"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-header-preprocessing"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "lto-backend"
+ flag_group {
+ iterate_over: "copts"
+ flag: "%{copts}"
+ }
+ }
+ }
+
+ # Please do not statically link libstdc++. This would probably lead to a lot
+ # of bloat since OpKernels need to use linkstatic=1 because b/27630669 and
+ # it could cause memory leaks since Python uses dlopen() on our libraries:
+ # https://stackoverflow.com/a/35015415
+ linker_flag: "-lstdc++"
+ linker_flag: "-lm"
+ linker_flag: "-lpthread"
+ linker_flag: "-l:/lib/x86_64-linux-gnu/libc-2.19.so"
+}
diff --git a/third_party/toolchains/clang6/README.md b/third_party/toolchains/clang6/README.md
new file mode 100644
index 0000000000..0c6be25a0e
--- /dev/null
+++ b/third_party/toolchains/clang6/README.md
@@ -0,0 +1,101 @@
+# TensorFlow Bazel Clang
+
+This is a specialized toolchain that uses an old Debian with a new Clang that
+can cross compile to any x86_64 microarchitecture. It's intended to build Linux
+binaries that only require the following ABIs:
+
+- GLIBC_2.18
+- CXXABI_1.3.7 (GCC 4.8.3)
+- GCC_4.2.0
+
+Which are available on at least the following Linux platforms:
+
+- Ubuntu 14+
+- CentOS 7+
+- Debian 8+
+- SuSE 13.2+
+- Mint 17.3+
+- Manjaro 0.8.11
+
+# System Install
+
+On Debian 8 (Jessie) Clang 6.0 can be installed as follows:
+
+```sh
+cat >>/etc/apt/sources.list <<'EOF'
+deb http://apt.llvm.org/jessie/ llvm-toolchain-jessie main
+deb-src http://apt.llvm.org/jessie/ llvm-toolchain-jessie main
+EOF
+wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add -
+apt-key fingerprint |& grep '6084 F3CF 814B 57C1 CF12 EFD5 15CF 4D18 AF4F 7421'
+apt-get update
+apt-get install clang lld
+```
+
+# Bazel Configuration
+
+This toolchain can compile TensorFlow in 2m30s on a 96-core Skylake GCE VM if
+the following `.bazelrc` settings are added:
+
+```
+startup --host_jvm_args=-Xmx30G
+startup --host_jvm_args=-Xms30G
+startup --host_jvm_args=-XX:MaxNewSize=3g
+startup --host_jvm_args=-XX:-UseAdaptiveSizePolicy
+startup --host_jvm_args=-XX:+UseConcMarkSweepGC
+startup --host_jvm_args=-XX:TargetSurvivorRatio=70
+startup --host_jvm_args=-XX:SurvivorRatio=6
+startup --host_jvm_args=-XX:+UseCMSInitiatingOccupancyOnly
+startup --host_jvm_args=-XX:CMSFullGCsBeforeCompaction=1
+startup --host_jvm_args=-XX:CMSInitiatingOccupancyFraction=75
+
+build --jobs=100
+build --local_resources=200000,100,100
+build --crosstool_top=@local_config_clang6//clang6
+build --noexperimental_check_output_files
+build --nostamp
+build --config=opt
+build --noexperimental_check_output_files
+build --copt=-march=native
+build --host_copt=-march=native
+```
+
+# x86_64 Microarchitectures
+
+## Intel CPU Line
+
+- 2003 P6 M SSE SSE2
+- 2004 prescott SSE3 SSSE3 (-march=prescott)
+- 2006 core X64 SSE4.1 (only on 45nm variety) (-march=core2)
+- 2008 nehalem SSE4.2 VT-x VT-d (-march=nehalem)
+- 2010 westmere CLMUL AES (-march=westmere)
+- 2012 sandybridge AVX TXT (-march=sandybridge)
+- 2012 ivybridge F16C MOVBE (-march=ivybridge)
+- 2013 haswell AVX2 TSX BMI2 FMA (-march=haswell)
+- 2014 broadwell RDSEED ADCX PREFETCHW (-march=broadwell - works on trusty gcc4.9)
+- 2015 skylake SGX ADX MPX AVX-512[xeon-only] (-march=skylake / -march=skylake-avx512 - needs gcc7)
+- 2018 cannonlake AVX-512 SHA (-march=cannonlake - needs clang5)
+
+## Intel Low Power CPU Line
+
+- 2013 silvermont SSE4.1 SSE4.2 VT-x (-march=silvermont)
+- 2016 goldmont SHA (-march=goldmont - needs clang5)
+
+## AMD CPU Line
+
+- 2003 k8 SSE SSE2 (-march=k8)
+- 2005 k8 (Venus) SSE3 (-march=k8-sse3)
+- 2008 barcelona SSE4a?! (-march=barcelona)
+- 2011 bulldozer SSE4.1 SSE4.2 CLMUL AVX AES FMA4?! (-march=bdver1)
+- 2011 piledriver FMA (-march=bdver2)
+- 2015 excavator AVX2 BMI2 MOVBE (-march=bdver4)
+
+## Google Compute Engine Supported CPUs
+
+- 2012 sandybridge 2.6gHz -march=sandybridge
+- 2012 ivybridge 2.5gHz -march=ivybridge
+- 2013 haswell 2.3gHz -march=haswell
+- 2014 broadwell 2.2gHz -march=broadwell
+- 2015 skylake 2.0gHz -march=skylake-avx512
+
+See: <https://cloud.google.com/compute/docs/cpu-platforms>
diff --git a/third_party/toolchains/clang6/clang.BUILD b/third_party/toolchains/clang6/clang.BUILD
new file mode 100644
index 0000000000..802d62c17c
--- /dev/null
+++ b/third_party/toolchains/clang6/clang.BUILD
@@ -0,0 +1,162 @@
+package(default_visibility = ["//visibility:public"])
+
+# Please note that the output of these tools is unencumbered.
+licenses(["restricted"]) # NCSA, GPLv3 (e.g. gold)
+
+filegroup(
+ name = "ar",
+ srcs = ["llvm/bin/llvm-ar"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "as",
+ srcs = ["llvm/bin/llvm-as"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "cpp",
+ srcs = ["llvm/bin/llvm-cpp"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "dwp",
+ srcs = ["llvm/bin/llvm-dwp"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "gcc",
+ srcs = ["llvm/bin/clang"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "gcov",
+ srcs = ["llvm/bin/llvm-cov"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "ld",
+ srcs = ["llvm/bin/ld.lld"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "nm",
+ srcs = ["llvm/bin/llvm-nm"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "objcopy",
+ srcs = ["llvm/bin/llvm-objcopy"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "objdump",
+ srcs = ["llvm/bin/llvm-objdump"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "profdata",
+ srcs = ["llvm/bin/llvm-profdata"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "strip",
+ srcs = ["sbin/strip"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "xray",
+ srcs = ["llvm/bin/llvm-xray"],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "includes",
+ srcs = glob(["llvm/lib/clang/6.0.0/include/**"]),
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "libraries",
+ srcs = glob([
+ "lib/*.*",
+ "lib/clang/6.0.0/lib/linux/*.*",
+ ]),
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "compiler_files",
+ srcs = [
+ ":as",
+ ":gcc",
+ ":includes",
+ ],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "linker_files",
+ srcs = [
+ ":ar",
+ ":ld",
+ ":libraries",
+ ],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = [
+ ":compiler_files",
+ ":dwp",
+ ":gcov",
+ ":linker_files",
+ ":nm",
+ ":objcopy",
+ ":objdump",
+ ":profdata",
+ ":strip",
+ ":xray",
+ ],
+ output_licenses = ["unencumbered"],
+)
+
+filegroup(
+ name = "empty",
+ srcs = [], # bazel crashes without this
+ output_licenses = ["unencumbered"],
+)
+
+cc_toolchain_suite(
+ name = "clang6",
+ toolchains = {
+ "k8|clang6": ":clang6-k8",
+ },
+)
+
+cc_toolchain(
+ name = "clang6-k8",
+ all_files = ":all_files",
+ compiler_files = ":compiler_files",
+ cpu = "k8",
+ dwp_files = ":dwp",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":linker_files",
+ objcopy_files = ":objcopy",
+ output_licenses = ["unencumbered"],
+ static_runtime_libs = [":empty"],
+ strip_files = ":strip",
+ supports_param_files = 1,
+)
diff --git a/third_party/toolchains/clang6/repo.bzl b/third_party/toolchains/clang6/repo.bzl
new file mode 100644
index 0000000000..b81f44506f
--- /dev/null
+++ b/third_party/toolchains/clang6/repo.bzl
@@ -0,0 +1,30 @@
+"""Repository rule for Debian 8 Jessie Clang-6.0 portable Linux builds."""
+
+def _clang6_configure(ctx):
+ # TODO(jart): It'd probably be better to use Bazel's struct.to_proto()
+ # method to generate a gigantic CROSSTOOL file that allows
+ # Clang to support everything.
+ ctx.symlink(
+ ctx.os.environ.get('TF_LLVM_PATH',
+ '/usr/lib/llvm-6.0'),
+ 'clang6/llvm')
+ ctx.symlink(
+ ctx.os.environ.get('STRIP', '/usr/bin/strip'),
+ 'clang6/sbin/strip')
+ ctx.symlink(
+ ctx.os.environ.get('OBJDUMP', '/usr/bin/objdump'),
+ 'clang6/sbin/objdump')
+ ctx.symlink(ctx.attr._build, 'clang6/BUILD')
+ ctx.template('clang6/CROSSTOOL', ctx.attr._crosstool, {
+ '%package(@local_config_clang6//clang6)%': str(ctx.path('clang6')),
+ })
+
+clang6_configure = repository_rule(
+ implementation = _clang6_configure,
+ attrs = {
+ '_build': attr.label(
+ default=str(Label('//third_party/toolchains/clang6:clang.BUILD'))),
+ '_crosstool': attr.label(
+ default=str(Label('//third_party/toolchains/clang6:CROSSTOOL.tpl'))),
+ },
+)