aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--SECURITY.md5
-rw-r--r--tensorflow/c/eager/BUILD27
-rw-r--r--tensorflow/c/eager/c_api.cc4
-rw-r--r--tensorflow/c/eager/c_api.h39
-rw-r--r--tensorflow/c/eager/c_api_debug.cc167
-rw-r--r--tensorflow/c/eager/c_api_debug_test.cc50
-rw-r--r--tensorflow/c/eager/c_api_internal.h8
-rw-r--r--tensorflow/c/eager/c_api_test.cc125
-rw-r--r--tensorflow/c/eager/c_api_test_util.cc163
-rw-r--r--tensorflow/c/eager/c_api_test_util.h53
-rw-r--r--tensorflow/c/eager/tape.h37
-rw-r--r--tensorflow/compiler/jit/BUILD1
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc2
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc3
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc11
-rw-r--r--tensorflow/compiler/jit/xla_device.cc36
-rw-r--r--tensorflow/compiler/jit/xla_device.h21
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc80
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h9
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.cc81
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h24
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc3
-rw-r--r--tensorflow/compiler/jit/xla_interpreter_device.cc11
-rw-r--r--tensorflow/compiler/jit/xla_tensor.cc8
-rw-r--r--tensorflow/compiler/jit/xla_tensor.h11
-rw-r--r--tensorflow/compiler/tests/depthwise_conv_op_test.py4
-rw-r--r--tensorflow/compiler/tests/eager_test.py24
-rw-r--r--tensorflow/compiler/tests/jit_test.py2
-rw-r--r--tensorflow/compiler/tests/variable_ops_test.py19
-rw-r--r--tensorflow/compiler/tests/xla_device_test.py7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/no_op.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/variable_ops.cc4
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.cc3
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc65
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h7
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc54
-rw-r--r--tensorflow/compiler/xla/BUILD31
-rw-r--r--tensorflow/compiler/xla/client/client.cc4
-rw-r--r--tensorflow/compiler/xla/client/client.h4
-rw-r--r--tensorflow/compiler/xla/client/local_client.h9
-rw-r--r--tensorflow/compiler/xla/layout_util.cc10
-rw-r--r--tensorflow/compiler/xla/layout_util.h4
-rw-r--r--tensorflow/compiler/xla/literal_comparison.cc18
-rw-r--r--tensorflow/compiler/xla/literal_util.cc41
-rw-r--r--tensorflow/compiler/xla/literal_util.h6
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc30
-rw-r--r--tensorflow/compiler/xla/reference_util.h37
-rw-r--r--tensorflow/compiler/xla/scanner.cc197
-rw-r--r--tensorflow/compiler/xla/scanner.h102
-rw-r--r--tensorflow/compiler/xla/scanner_test.cc124
-rw-r--r--tensorflow/compiler/xla/service/BUILD5
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc35
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc33
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc21
-rw-r--r--tensorflow/compiler/xla/service/compiler.cc6
-rw-r--r--tensorflow/compiler/xla/service/compiler.h11
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc657
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD31
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc17
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc72
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc121
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h13
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_options.cc28
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_options.h33
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_manager.cc28
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_manager.h9
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc20
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc248
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc347
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h42
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_executor_util.cc151
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_executor_util.h46
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc51
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h44
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc60
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h7
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc95
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc46
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc522
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h52
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc321
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc76
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h21
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc24
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h8
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc27
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h16
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc73
-rw-r--r--tensorflow/compiler/xla/shape_util.cc17
-rw-r--r--tensorflow/compiler/xla/shape_util.h26
-rw-r--r--tensorflow/compiler/xla/tests/BUILD52
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc107
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc18
-rw-r--r--tensorflow/compiler/xla/tools/BUILD1
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc23
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.h4
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc71
-rw-r--r--tensorflow/compiler/xla/util.h22
-rw-r--r--tensorflow/contrib/android/jni/run_stats_jni.cc4
-rw-r--r--tensorflow/contrib/autograph/CONTRIBUTING.md5
-rw-r--r--tensorflow/contrib/autograph/STYLE_GUIDE.md76
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements.py49
-rw-r--r--tensorflow/contrib/autograph/converters/builtin_functions.py3
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/type_info.py55
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py13
-rw-r--r--tensorflow/contrib/autograph/pyct/transformer.py48
-rw-r--r--tensorflow/contrib/autograph/pyct/transformer_test.py4
-rw-r--r--tensorflow/contrib/batching/BUILD1
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc168
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py281
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py147
-rw-r--r--tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc74
-rw-r--r--tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py10
-rw-r--r--tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py3
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py118
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py167
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py9
-rw-r--r--tensorflow/contrib/checkpoint/python/BUILD21
-rw-r--r--tensorflow/contrib/checkpoint/python/containers.py7
-rw-r--r--tensorflow/contrib/checkpoint/python/containers_test.py9
-rw-r--r--tensorflow/contrib/cmake/tf_c.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake6
-rw-r--r--tensorflow/contrib/data/__init__.py1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD17
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py17
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py167
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py110
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py45
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py5
-rw-r--r--tensorflow/contrib/distribute/python/BUILD37
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py157
-rw-r--r--tensorflow/contrib/distribute/python/combinations_test.py24
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py41
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py135
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils.py45
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils_test.py152
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py148
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py13
-rw-r--r--tensorflow/contrib/distribute/python/monitor_test.py3
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py160
-rw-r--r--tensorflow/contrib/distributions/python/ops/autoregressive.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/batch_reshape.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/chain.py30
-rw-r--r--tensorflow/contrib/distributions/python/ops/binomial.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/cauchy.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/chi2.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py13
-rw-r--r--tensorflow/contrib/distributions/python/ops/deterministic.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/geometric.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/gumbel.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/half_normal.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/independent.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/inverse_gamma.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/logistic.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture_same_family.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_tril.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/negative_binomial.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/onehot_categorical.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson_lognormal.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/statistical_testing.py274
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_student_t.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/wishart.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/BUILD2
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/BUILD39
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py382
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py162
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py86
-rw-r--r--tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb620
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb474
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb443
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py37
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py10
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py10
-rw-r--r--tensorflow/contrib/eager/python/saver_test.py45
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks.py2
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py2
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib.py21
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib_test.py30
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py112
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py71
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py252
-rw-r--r--tensorflow/contrib/linear_optimizer/python/sdca_estimator.py29
-rw-r--r--tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py84
-rw-r--r--tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py29
-rw-r--r--tensorflow/contrib/lite/build_def.bzl3
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h1
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc1
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv_test.cc43
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD86
-rw-r--r--tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc162
-rw-r--r--tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc330
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc333
-rw-r--r--tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc241
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h73
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h6861
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h5
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h220
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h296
-rw-r--r--tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc102
-rw-r--r--tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc227
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h3
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc19
-rw-r--r--tensorflow/contrib/lite/kernels/internal/test_util.cc121
-rw-r--r--tensorflow/contrib/lite/kernels/internal/test_util.h104
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h1
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.cc1
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm_test.cc30
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc158
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h8
-rw-r--r--tensorflow/contrib/lite/profiling/BUILD27
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer.cc140
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer.h58
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer_test.cc116
-rw-r--r--tensorflow/contrib/lite/python/BUILD27
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py162
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model_test.py284
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py106
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/BUILD4
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc2
-rw-r--r--tensorflow/contrib/lite/python/lite.py192
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py323
-rw-r--r--tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc1
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc20
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h2
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc32
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md191
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc2
-rw-r--r--tensorflow/contrib/lite/toco/python/BUILD4
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.h2
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager_test.cc3
-rw-r--r--tensorflow/contrib/opt/python/training/adamax_test.py8
-rw-r--r--tensorflow/contrib/optimizer_v2/momentum_test.py11
-rw-r--r--tensorflow/contrib/signal/BUILD2
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer.cc153
-rw-r--r--tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc6
-rw-r--r--tensorflow/contrib/tensorrt/BUILD2
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc4
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc2
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto2
-rw-r--r--tensorflow/contrib/tpu/profiler/tpu_profiler.proto3
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py151
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py234
-rw-r--r--tensorflow/core/BUILD373
-rw-r--r--tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt11
-rw-r--r--tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CollectiveBcastRecv.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CollectiveBcastSend.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CollectiveReduce.pbtxt6
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc27
-rw-r--r--tensorflow/core/common_runtime/direct_session.h3
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc6
-rw-r--r--tensorflow/core/common_runtime/executor.cc18
-rw-r--r--tensorflow/core/common_runtime/executor.h1
-rw-r--r--tensorflow/core/common_runtime/executor_test.cc (renamed from tensorflow/core/distributed_runtime/executor_test.cc)0
-rw-r--r--tensorflow/core/common_runtime/function.cc3
-rw-r--r--tensorflow/core/common_runtime/graph_runner.cc3
-rw-r--r--tensorflow/core/common_runtime/renamed_device.h5
-rw-r--r--tensorflow/core/common_runtime/scoped_allocator_mgr.cc12
-rw-r--r--tensorflow/core/common_runtime/scoped_allocator_mgr.h11
-rw-r--r--tensorflow/core/common_runtime/testlib_ops.cc (renamed from tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc)20
-rw-r--r--tensorflow/core/distributed_runtime/BUILD19
-rw-r--r--tensorflow/core/distributed_runtime/master_test.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD16
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc3
-rw-r--r--tensorflow/core/framework/device_base.h11
-rw-r--r--tensorflow/core/framework/function.h2
-rw-r--r--tensorflow/core/framework/op.cc4
-rw-r--r--tensorflow/core/framework/op_kernel.cc54
-rw-r--r--tensorflow/core/framework/resource_handle.cc25
-rw-r--r--tensorflow/core/framework/resource_handle.h9
-rw-r--r--tensorflow/core/framework/resource_mgr.cc18
-rw-r--r--tensorflow/core/framework/resource_mgr.h53
-rw-r--r--tensorflow/core/framework/resource_mgr_test.cc27
-rw-r--r--tensorflow/core/framework/tensor.cc13
-rw-r--r--tensorflow/core/framework/tensor.h1
-rw-r--r--tensorflow/core/framework/variant.cc33
-rw-r--r--tensorflow/core/framework/variant_encode_decode.h10
-rw-r--r--tensorflow/core/graph/graph.cc3
-rw-r--r--tensorflow/core/graph/graph.h2
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc2
-rw-r--r--tensorflow/core/grappler/op_types.cc12
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD46
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc447
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h84
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD76
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc217
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h81
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc142
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc133
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h46
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc184
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc28
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc87
-rw-r--r--tensorflow/core/grappler/optimizers/remapper_test.cc37
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc929
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h107
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc243
-rw-r--r--tensorflow/core/grappler/utils/functions.cc5
-rw-r--r--tensorflow/core/grappler/utils/functions_test.cc12
-rw-r--r--tensorflow/core/kernels/BUILD3
-rw-r--r--tensorflow/core/kernels/boosted_trees/stats_ops.cc41
-rw-r--r--tensorflow/core/kernels/cwise_op_not_equal_to_1.cc23
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc74
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc7
-rw-r--r--tensorflow/core/kernels/function_ops.cc1
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc14
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op_test.cc38
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc62
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.h33
-rw-r--r--tensorflow/core/kernels/scoped_allocator_ops_test.cc41
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt171
-rw-r--r--tensorflow/core/ops/dataset_ops.cc6
-rw-r--r--tensorflow/core/ops/ops.pbtxt55
-rw-r--r--tensorflow/core/ops/scoped_allocator_ops.cc37
-rw-r--r--tensorflow/core/platform/cloud/curl_http_request.cc107
-rw-r--r--tensorflow/core/platform/cloud/curl_http_request.h18
-rw-r--r--tensorflow/core/platform/cloud/curl_http_request_test.cc76
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc3
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc16
-rw-r--r--tensorflow/core/platform/default/build_config.bzl10
-rw-r--r--tensorflow/core/platform/default/string_coding.cc30
-rw-r--r--tensorflow/core/platform/default/string_coding.h98
-rw-r--r--tensorflow/core/platform/env.cc4
-rw-r--r--tensorflow/core/platform/tensor_coding.cc36
-rw-r--r--tensorflow/core/platform/tensor_coding.h10
-rw-r--r--tensorflow/core/platform/variant_coding.cc71
-rw-r--r--tensorflow/core/platform/variant_coding.h40
-rw-r--r--tensorflow/core/platform/windows/env.cc7
-rw-r--r--tensorflow/core/platform/windows/wide_char.h46
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.cc1
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.h19
-rw-r--r--tensorflow/core/protobuf/config.proto25
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto10
-rw-r--r--tensorflow/core/protobuf/worker.proto8
-rw-r--r--tensorflow/core/util/stat_summarizer.cc300
-rw-r--r--tensorflow/core/util/stat_summarizer.h191
-rw-r--r--tensorflow/core/util/stat_summarizer_options.h43
-rw-r--r--tensorflow/core/util/stats_calculator.cc289
-rw-r--r--tensorflow/core/util/stats_calculator.h189
-rw-r--r--tensorflow/docs_src/community/security.md7
-rw-r--r--tensorflow/docs_src/get_started/datasets_quickstart.md4
-rw-r--r--tensorflow/docs_src/get_started/get_started_for_beginners.md751
-rw-r--r--tensorflow/docs_src/get_started/index.md22
-rw-r--r--tensorflow/docs_src/get_started/leftnav_files11
-rw-r--r--tensorflow/docs_src/install/install_mac.md6
-rw-r--r--tensorflow/docs_src/install/install_windows.md6
-rw-r--r--tensorflow/docs_src/programmers_guide/checkpoints.md (renamed from tensorflow/docs_src/get_started/checkpoints.md)0
-rw-r--r--tensorflow/docs_src/programmers_guide/custom_estimators.md (renamed from tensorflow/docs_src/get_started/custom_estimators.md)12
-rw-r--r--tensorflow/docs_src/programmers_guide/estimators.md2
-rw-r--r--tensorflow/docs_src/programmers_guide/feature_columns.md (renamed from tensorflow/docs_src/get_started/feature_columns.md)2
-rw-r--r--tensorflow/docs_src/programmers_guide/index.md24
-rw-r--r--tensorflow/docs_src/programmers_guide/leftnav_files16
-rw-r--r--tensorflow/docs_src/programmers_guide/low_level_intro.md6
-rw-r--r--tensorflow/docs_src/programmers_guide/premade_estimators.md (renamed from tensorflow/docs_src/get_started/premade_estimators.md)15
-rw-r--r--tensorflow/docs_src/programmers_guide/using_tpu.md4
-rw-r--r--tensorflow/docs_src/tutorials/kernel_methods.md2
-rw-r--r--tensorflow/docs_src/tutorials/layers.md10
-rw-r--r--tensorflow/docs_src/tutorials/linear.md2
-rw-r--r--tensorflow/docs_src/tutorials/recurrent_quickdraw.md2
-rw-r--r--tensorflow/go/op/wrappers.go598
-rw-r--r--tensorflow/java/BUILD10
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h22
-rw-r--r--tensorflow/java/src/gen/cc/op_gen_main.cc10
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc179
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.h7
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.cc14
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.h14
-rw-r--r--tensorflow/java/src/gen/cc/source_writer_test.cc99
-rw-r--r--tensorflow/python/BUILD55
-rw-r--r--tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py36
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py4
-rw-r--r--tensorflow/python/debug/BUILD2
-rw-r--r--tensorflow/python/debug/lib/grpc_debug_test_server.py13
-rw-r--r--tensorflow/python/debug/lib/source_remote.py23
-rw-r--r--tensorflow/python/debug/lib/source_remote_test.py46
-rw-r--r--tensorflow/python/eager/BUILD3
-rw-r--r--tensorflow/python/eager/backprop.py45
-rw-r--r--tensorflow/python/eager/function.py21
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc181
-rw-r--r--tensorflow/python/eager/pywrap_tfe.h32
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc166
-rw-r--r--tensorflow/python/eager/tensor_test.py5
-rw-r--r--tensorflow/python/estimator/estimator.py49
-rw-r--r--tensorflow/python/estimator/estimator_test.py1
-rw-r--r--tensorflow/python/estimator/exporter.py4
-rw-r--r--tensorflow/python/estimator/exporter_test.py8
-rw-r--r--tensorflow/python/estimator/keras.py40
-rw-r--r--tensorflow/python/estimator/training.py6
-rw-r--r--tensorflow/python/framework/dtypes.py28
-rw-r--r--tensorflow/python/framework/ops.py32
-rw-r--r--tensorflow/python/framework/ops_test.py26
-rw-r--r--tensorflow/python/framework/tensor_shape.py5
-rw-r--r--tensorflow/python/framework/test_util.py20
-rw-r--r--tensorflow/python/framework/test_util_test.py5
-rwxr-xr-xtensorflow/python/keras/BUILD1
-rw-r--r--tensorflow/python/keras/engine/network.py14
-rw-r--r--tensorflow/python/keras/engine/training.py91
-rw-r--r--tensorflow/python/keras/engine/training_test.py132
-rw-r--r--tensorflow/python/keras/engine/training_utils.py29
-rw-r--r--tensorflow/python/keras/layers/cudnn_recurrent.py6
-rw-r--r--tensorflow/python/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/kernel_tests/accumulate_n_eager_test.py7
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py38
-rw-r--r--tensorflow/python/kernel_tests/distributions/bijector_test.py32
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py5
-rw-r--r--tensorflow/python/lib/core/py_exception_registry.cc4
-rw-r--r--tensorflow/python/lib/core/py_func.cc4
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.cc13
-rw-r--r--tensorflow/python/lib/core/py_util.cc3
-rw-r--r--tensorflow/python/lib/core/safe_ptr.h1
-rw-r--r--tensorflow/python/lib/io/file_io.py58
-rw-r--r--tensorflow/python/lib/io/file_io_test.py91
-rw-r--r--tensorflow/python/ops/collective_ops.py133
-rw-r--r--tensorflow/python/ops/collective_ops_test.py80
-rw-r--r--tensorflow/python/ops/distributions/bernoulli.py2
-rw-r--r--tensorflow/python/ops/distributions/beta.py4
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py48
-rw-r--r--tensorflow/python/ops/distributions/categorical.py2
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py2
-rw-r--r--tensorflow/python/ops/distributions/dirichlet_multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/distribution.py2
-rw-r--r--tensorflow/python/ops/distributions/exponential.py5
-rw-r--r--tensorflow/python/ops/distributions/gamma.py4
-rw-r--r--tensorflow/python/ops/distributions/laplace.py5
-rw-r--r--tensorflow/python/ops/distributions/multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/normal.py5
-rw-r--r--tensorflow/python/ops/distributions/student_t.py4
-rw-r--r--tensorflow/python/ops/distributions/transformed_distribution.py27
-rw-r--r--tensorflow/python/ops/distributions/uniform.py3
-rw-r--r--tensorflow/python/ops/distributions/util.py1
-rw-r--r--tensorflow/python/ops/gradients_impl.py29
-rw-r--r--tensorflow/python/ops/gradients_test.py48
-rw-r--r--tensorflow/python/ops/image_ops_impl.py89
-rw-r--r--tensorflow/python/ops/image_ops_test.py132
-rw-r--r--tensorflow/python/ops/nn_ops.py19
-rw-r--r--tensorflow/python/ops/variables.py26
-rw-r--r--tensorflow/python/pywrap_tfe.i4
-rw-r--r--tensorflow/python/saved_model/builder_impl.py81
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py147
-rw-r--r--tensorflow/python/training/adam_test.py7
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py20
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py22
-rw-r--r--tensorflow/python/training/checkpointable/BUILD39
-rw-r--r--tensorflow/python/training/checkpointable/base.py8
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py251
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_base.py27
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py219
-rw-r--r--tensorflow/python/training/gradient_descent.py3
-rw-r--r--tensorflow/python/training/gradient_descent_test.py23
-rw-r--r--tensorflow/python/training/momentum_test.py11
-rw-r--r--tensorflow/python/training/session_manager.py8
-rw-r--r--tensorflow/python/training/supervisor.py9
-rw-r--r--tensorflow/python/training/training_util.py20
-rw-r--r--tensorflow/python/training/warm_starting_util.py89
-rw-r--r--tensorflow/python/training/warm_starting_util_test.py41
-rw-r--r--tensorflow/python/util/stat_summarizer.i2
-rw-r--r--tensorflow/python/util/tf_inspect.py128
-rw-r--r--tensorflow/python/util/tf_inspect_test.py46
-rw-r--r--tensorflow/python/util/util.cc20
-rw-r--r--tensorflow/security/advisory/tfsa-2018-001.md34
-rw-r--r--tensorflow/security/advisory/tfsa-2018-002.md33
-rw-r--r--tensorflow/security/advisory/tfsa-2018-003.md48
-rw-r--r--tensorflow/security/advisory/tfsa-2018-004.md35
-rw-r--r--tensorflow/security/advisory/tfsa-2018-005.md36
-rw-r--r--tensorflow/security/advisory/tfsa-2018-006.md35
-rw-r--r--tensorflow/security/index.md18
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc70
-rw-r--r--tensorflow/stream_executor/dnn.h2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt172
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt265
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt99
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt270
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-event.pbtxt180
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt196
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt122
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt173
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt152
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt152
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt100
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt148
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt100
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt239
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt100
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt120
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt100
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt150
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt200
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt109
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt197
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt146
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt96
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt126
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt126
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt116
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt180
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-summary.pbtxt230
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt106
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt149
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt86
-rw-r--r--tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt100
-rw-r--r--tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt123
-rw-r--r--tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt100
-rw-r--r--tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt373
-rw-r--r--tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt288
-rw-r--r--tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt99
-rw-r--r--tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt120
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt180
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt146
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt86
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt126
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt116
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt180
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt230
-rw-r--r--tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt96
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt86
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt87
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt87
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt87
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt100
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt110
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt115
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt100
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt110
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt89
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt89
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt99
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt119
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt178
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt98
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt128
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt2
-rw-r--r--tensorflow/tools/api/lib/api_objects.proto7
-rw-r--r--tensorflow/tools/api/lib/python_object_to_proto_visitor.py15
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py20
-rw-r--r--tensorflow/tools/benchmark/benchmark_model.cc10
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_cc_core.sh2
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py2_core.sh2
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh2
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_core.sh2
-rwxr-xr-xtensorflow/tools/ci_build/linux/gpu/run_cc_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/gpu/run_py3_core.sh1
-rw-r--r--tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh12
-rwxr-xr-xtensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh1
-rwxr-xr-xtensorflow/tools/dist_test/build_server.sh2
-rwxr-xr-xtensorflow/tools/dist_test/local_test.sh2
-rw-r--r--tensorflow/tools/docs/generate_lib.py25
-rw-r--r--tensorflow/tools/docs/parser.py68
-rw-r--r--tensorflow/tools/docs/pretty_docs.py66
-rw-r--r--tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc2
-rw-r--r--tensorflow/workspace.bzl9
-rw-r--r--third_party/mkl/BUILD4
-rw-r--r--third_party/python_runtime/BUILD (renamed from util/python/BUILD)2
604 files changed, 26928 insertions, 16980 deletions
diff --git a/SECURITY.md b/SECURITY.md
index 01886b613e..0a4be37cbc 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -168,7 +168,7 @@ below).
Please use a descriptive subject line for your report email. After the initial
reply to your report, the security team will endeavor to keep you informed of
-the progress being made towards a fix and announcement.
+the progress being made towards a fix and announcement.
In addition, please include the following information along with your report:
@@ -246,5 +246,8 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
| Type | Versions affected | Reported by | Additional Information |
|--------------------|:-----------------:|-----------------------|-----------------------------|
+| TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-003.md) |
+| GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-002.md) |
+| BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-001.md) |
| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 9ce781fab0..f265da2c2c 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -14,6 +14,7 @@ tf_cuda_library(
name = "c_api",
srcs = [
"c_api.cc",
+ "c_api_debug.cc",
"c_api_internal.h",
],
hdrs = ["c_api.h"],
@@ -45,6 +46,7 @@ tf_cuda_library(
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/jit",
+ "//tensorflow/compiler/jit:xla_device",
],
"//conditions:default": [],
}) + [
@@ -99,9 +101,31 @@ tf_cuda_library(
],
)
+tf_cuda_library(
+ name = "c_api_test_util",
+ testonly = 1,
+ srcs = ["c_api_test_util.cc"],
+ hdrs = ["c_api_test_util.h"],
+ visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow:__subpackages__",
+ ],
+ deps = [
+ ":c_api",
+ "//tensorflow/c:c_test_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cuda_cc_test(
name = "c_api_test",
- srcs = ["c_api_test.cc"],
+ srcs = [
+ "c_api_debug_test.cc",
+ "c_api_test.cc",
+ ],
extra_copts = tfe_xla_copts(),
tags = [
"guitar",
@@ -109,6 +133,7 @@ tf_cuda_cc_test(
],
deps = [
":c_api",
+ ":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 216210c88c..81221c4078 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -73,10 +73,6 @@ string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
-#ifdef TENSORFLOW_EAGER_USE_XLA
-std::atomic_int_fast64_t func_id_generator(0);
-#endif // TENSORFLOW_EAGER_USE_XLA
-
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 574a097e0d..1862af3ce2 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -191,6 +191,45 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(
TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name,
TF_Status* status);
+// Debugging/Profiling information for TFE_TensorHandle
+//
+// TFE_TensorDebugInfo contains information useful for debugging and
+// profiling tensors.
+typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
+
+// Retrieves TFE_TensorDebugInfo for `handle`.
+// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and caller
+// is responsible for deleting returned TFE_TensorDebugInfo.
+// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to appropriate
+// error and nullptr is returned. This function can block till the operation
+// that produces `handle` has completed.
+TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
+ TFE_TensorHandle* handle, TF_Status* status);
+
+// Deletes `debug_info`.
+TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
+ TFE_TensorDebugInfo* debug_info);
+
+// Returns the number of dimensions used to represent the tensor on its device.
+// The number of dimensions used to reprensent the tensor on device can be
+// different from the number returned by TFE_TensorHandleNumDims.
+// The return value was current at the time of TFE_TensorDebugInfo creation.
+TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
+ TFE_TensorDebugInfo* debug_info);
+
+// Returns the number of elements in dimension `dim_index`.
+// Tensor representation on device can be transposed from its representation
+// on host. The data contained in dimension `dim_index` on device
+// can correspond to the data contained in another dimension in on-host
+// representation. The dimensions are indexed using the standard TensorFlow
+// major-to-minor order (slowest varying dimension first),
+// not the XLA's minor-to-major order.
+// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns
+// the number of elements in a dimension after padding.
+// The return value was current at the time of TFE_TensorDebugInfo creation.
+TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim(
+ TFE_TensorDebugInfo* debug_info, int dim_index);
+
// Description of the TensorFlow op to execute.
//
// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e.,
diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc
new file mode 100644
index 0000000000..5006b76f19
--- /dev/null
+++ b/tensorflow/c/eager/c_api_debug.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include <vector>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
+#ifdef TENSORFLOW_EAGER_USE_XLA
+#include "tensorflow/compiler/jit/xla_device.h"
+#endif // TENSORFLOW_EAGER_USE_XLA
+
+using tensorflow::int64;
+using tensorflow::string;
+
+namespace {
+
+std::vector<int64> TensorShapeAsVector(TFE_TensorHandle* handle,
+ TF_Status* status) {
+ std::vector<int64> shape;
+ int rank = TFE_TensorHandleNumDims(handle, status);
+ if (!status->status.ok()) {
+ return shape;
+ }
+ shape.reserve(rank);
+ for (int i = 0; i < rank; ++i) {
+ shape.push_back(TFE_TensorHandleDim(handle, i, status));
+ if (!status->status.ok()) {
+ return shape;
+ }
+ }
+ return shape;
+}
+
+} // namespace
+
+extern "C" {
+
+TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
+ TFE_TensorHandle* handle, TF_Status* status) {
+ const tensorflow::Tensor* tensor;
+ status->status = handle->handle->Tensor(&tensor);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+
+ tensorflow::Device* device;
+ status->status = handle->handle->Device(&device);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+ // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
+ tensorflow::XlaDevice* xla_device =
+ dynamic_cast<tensorflow::XlaDevice*>(device);
+ if (xla_device != nullptr) {
+ tensorflow::XlaDevice::PaddedShapeFn shape_fn =
+ xla_device->metadata().padded_shape_fn();
+ xla::Shape padded_shape;
+ status->status = shape_fn(*tensor, &padded_shape);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ if (VLOG_IS_ON(3)) {
+ std::vector<int64> shape_to_log = TensorShapeAsVector(handle, status);
+ if (!status->status.ok()) {
+ // Ignore the status here as we are simply logging.
+ status->status = tensorflow::Status::OK();
+ } else {
+ VLOG(3) << "Fully padded shape of ["
+ << tensorflow::str_util::Join(shape_to_log, ", ") << "] is "
+ << padded_shape.DebugString();
+ }
+ }
+
+ if (xla::ShapeUtil::IsTuple(padded_shape)) {
+ if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
+ // Currently, the only case of XlaTensor containing a tuple shape is to
+ // represent 64 bit ints, doubles, and complex numbers (we don't support
+ // 64bit complex numbers).
+ status->status = tensorflow::errors::InvalidArgument(
+ "XlaTensors should only contain tuples of size 2. Shape: ",
+ padded_shape.DebugString());
+ return nullptr;
+ }
+
+ // shape0 is not a const& because we will assign it to padded_shape below.
+ // It is illegal to assign a part of a message to itself.
+ xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
+ const xla::Shape& shape1 =
+ xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
+ if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "XlaTensors should not contain nested tuples. Shape: ",
+ padded_shape.DebugString());
+ return nullptr;
+ }
+ if (!xla::ShapeUtil::Equal(shape0, shape1)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Subshapes of XlaTensors should be the same. Shape: ",
+ padded_shape.DebugString());
+ return nullptr;
+ }
+
+ // Since the only case we handle here are two equal subshapes, we
+ // simply return one of them. The caller will interpret it as this
+ // shape directly storing the 64bit types. This approximation is good
+ // enough for this API's debugging use case.
+ padded_shape = shape0;
+ }
+
+ int rank = padded_shape.dimensions_size();
+ std::vector<int64> dev_dims;
+ dev_dims.reserve(rank);
+ if (rank == 1) {
+ // Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
+ dev_dims.push_back(padded_shape.dimensions(0));
+ } else {
+ for (int i = rank - 1; i >= 0; --i) {
+ int64 dim_index = padded_shape.layout().minor_to_major(i);
+ dev_dims.push_back(padded_shape.dimensions(dim_index));
+ }
+ }
+ status->status = tensorflow::Status::OK();
+ return new TFE_TensorDebugInfo(dev_dims);
+ }
+#endif // TENSORFLOW_EAGER_USE_XLA
+
+ // If the tensor is not an XLA tensor, the device shape is
+ // the same as regular tensor shape.
+ std::vector<int64> dev_dims = TensorShapeAsVector(handle, status);
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ return new TFE_TensorDebugInfo(dev_dims);
+}
+
+TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
+ TFE_TensorDebugInfo* debug_info) {
+ delete debug_info;
+}
+
+TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
+ TFE_TensorDebugInfo* debug_info) {
+ return debug_info->dev_dims.size();
+}
+
+TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim(
+ TFE_TensorDebugInfo* debug_info, int dim_index) {
+ return debug_info->dev_dims[dim_index];
+}
+
+} // extern "C"
diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc
new file mode 100644
index 0000000000..cddb9f6e00
--- /dev/null
+++ b/tensorflow/c/eager/c_api_debug_test.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include <string.h>
+#include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+TEST(CApiDebug, ScalarCPU) {
+ TFE_TensorHandle* h = TestScalarTensorHandle();
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ ASSERT_EQ(0, TFE_TensorDebugInfoOnDeviceNumDims(debug_info));
+
+ TFE_DeleteTensorDebugInfo(debug_info);
+ TFE_DeleteTensorHandle(h);
+ TF_DeleteStatus(status);
+}
+
+TEST(CApiDebug, 2DCPU) {
+ TFE_TensorHandle* h = TestMatrixTensorHandle3X2();
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ ASSERT_EQ(2, TFE_TensorDebugInfoOnDeviceNumDims(debug_info));
+ // Shape is the same for CPU tensors.
+ EXPECT_EQ(3, TFE_TensorDebugInfoOnDeviceDim(debug_info, 0));
+ EXPECT_EQ(2, TFE_TensorDebugInfoOnDeviceDim(debug_info, 1));
+
+ TFE_DeleteTensorDebugInfo(debug_info);
+ TFE_DeleteTensorHandle(h);
+ TF_DeleteStatus(status);
+}
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 2b8384d720..04a6efc47c 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -107,6 +107,14 @@ struct TFE_TensorHandle {
tensorflow::TensorHandle* handle;
};
+struct TFE_TensorDebugInfo {
+ TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
+ : dev_dims(dims) {}
+
+ // Fully-padded, minor-to-major.
+ std::vector<tensorflow::int64> dev_dims;
+};
+
struct TFE_Op {
// t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
// primitive operation.
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 49646bb735..27ff5f7211 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
+#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -32,122 +33,6 @@ using tensorflow::string;
namespace {
-TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
- int64_t dims[] = {2, 2};
- double data[] = {1.0, 2.0, 3.0, 4.0};
- TF_Tensor* t = TF_AllocateTensor(
- TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
- memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TF_Status* status = TF_NewStatus();
- TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteTensor(t);
- TF_DeleteStatus(status);
- return th;
-}
-
-TFE_TensorHandle* TestMatrixTensorHandle() {
- int64_t dims[] = {2, 2};
- float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
- TF_Tensor* t = TF_AllocateTensor(
- TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
- memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TF_Status* status = TF_NewStatus();
- TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteTensor(t);
- TF_DeleteStatus(status);
- return th;
-}
-
-TFE_TensorHandle* TestMatrixTensorHandle3X2() {
- int64_t dims[] = {3, 2};
- double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
- TF_Tensor* t = TF_AllocateTensor(
- TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
- memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TF_Status* status = TF_NewStatus();
- TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteTensor(t);
- TF_DeleteStatus(status);
- return th;
-}
-
-TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
- TF_Status* status = TF_NewStatus();
-
- TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(op, a, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(op, b, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteStatus(status);
- TFE_OpSetAttrBool(op, "transpose_a", 0);
- TFE_OpSetAttrBool(op, "transpose_b", 0);
- TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
-
- return op;
-}
-
-TFE_TensorHandle* TestAxisTensorHandle() {
- int64_t dims[] = {1};
- int data[] = {1};
- TF_Tensor* t = TF_AllocateTensor(
- TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
- memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
- TF_Status* status = TF_NewStatus();
- TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TF_DeleteTensor(t);
- TF_DeleteStatus(status);
- return th;
-}
-
-TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
- TFE_TensorHandle* axis) {
- TF_Status* status = TF_NewStatus();
-
- TFE_Op* op = TFE_NewOp(ctx, "Min", status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(op, input, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpAddInput(op, axis, status);
- CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
- TFE_OpSetAttrBool(op, "keep_dims", 1);
- TFE_OpSetAttrType(op, "Tidx", TF_INT32);
- TF_DeleteStatus(status);
- TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
-
- return op;
-}
-
-// If there is a GPU device, returns true and sets 'gpu_device_name'
-// accordingly.
-bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) {
- std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
- TF_NewStatus(), TF_DeleteStatus);
- TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
- CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
- const int num_devices = TF_DeviceListCount(devices);
- for (int i = 0; i < num_devices; ++i) {
- const string device_type(TF_DeviceListType(devices, i, status.get()));
- CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
- const string device_name(TF_DeviceListName(devices, i, status.get()));
- CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
- if (device_type == "GPU") {
- *gpu_device_name = device_name;
- LOG(INFO) << "Found GPU device " << device_name;
- TF_DeleteDeviceList(devices);
- return true;
- }
- }
- TF_DeleteDeviceList(devices);
- return false;
-}
-
void BM_InitOp(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
@@ -536,7 +421,7 @@ void TensorHandleSilentCopy(bool async) {
// Disable the test if no GPU is present.
string gpu_device_name;
- if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@@ -583,7 +468,7 @@ void TensorHandleSilentCopyLocal(bool async) {
// Disable the test if no GPU is present.
string gpu_device_name;
- if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@@ -624,7 +509,7 @@ void SetAndGetOpDevices(bool async) {
// Disable the test if no GPU is present.
string gpu_device_name;
- if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_OpSetDevice(matmul, "GPU:0", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
const char* device_name = TFE_OpGetDevice(matmul, status);
@@ -688,7 +573,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
- TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2();
+ TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2();
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0",
status);
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
new file mode 100644
index 0000000000..5607c9dcb0
--- /dev/null
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -0,0 +1,163 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/c_api_test_util.h"
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+using tensorflow::string;
+
+TFE_TensorHandle* TestScalarTensorHandle() {
+ float data[] = {1.0f};
+ TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
+ int64_t dims[] = {2, 2};
+ double data[] = {1.0, 2.0, 3.0, 4.0};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_TensorHandle* TestMatrixTensorHandle() {
+ int64_t dims[] = {2, 2};
+ float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
+ int64_t dims[] = {3, 2};
+ double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_TensorHandle* TestMatrixTensorHandle3X2() {
+ int64_t dims[] = {3, 2};
+ float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, a, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, b, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrBool(op, "transpose_a", 0);
+ TFE_OpSetAttrBool(op, "transpose_b", 0);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+ return op;
+}
+
+TFE_TensorHandle* TestAxisTensorHandle() {
+ int64_t dims[] = {1};
+ int data[] = {1};
+ TF_Tensor* t = TF_AllocateTensor(
+ TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+ memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+ return th;
+}
+
+TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
+ TFE_TensorHandle* axis) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "Min", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, input, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, axis, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpSetAttrBool(op, "keep_dims", 1);
+ TFE_OpSetAttrType(op, "Tidx", TF_INT32);
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
+
+ return op;
+}
+
+bool GetDeviceName(TFE_Context* ctx, string* device_name,
+ const char* device_type) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+ CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ const int num_devices = TF_DeviceListCount(devices);
+ for (int i = 0; i < num_devices; ++i) {
+ const string dev_type(TF_DeviceListType(devices, i, status.get()));
+ CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+ const string dev_name(TF_DeviceListName(devices, i, status.get()));
+ CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+ if (dev_type == device_type) {
+ *device_name = dev_name;
+ LOG(INFO) << "Found " << device_type << " device " << *device_name;
+ TF_DeleteDeviceList(devices);
+ return true;
+ }
+ }
+ TF_DeleteDeviceList(devices);
+ return false;
+}
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
new file mode 100644
index 0000000000..474cae67c8
--- /dev/null
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
+#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
+
+#include "tensorflow/c/eager/c_api.h"
+
+#include "tensorflow/core/platform/types.h"
+
+// Return a tensor handle containing a float scalar
+TFE_TensorHandle* TestScalarTensorHandle();
+
+// Return a tensor handle containing a 2x2 matrix of doubles
+TFE_TensorHandle* DoubleTestMatrixTensorHandle();
+
+// Return a tensor handle containing a 2x2 matrix of floats
+TFE_TensorHandle* TestMatrixTensorHandle();
+
+// Return a tensor handle containing a 3x2 matrix of doubles
+TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
+
+// Return a tensor handle containing a 3x2 matrix of floats
+TFE_TensorHandle* TestMatrixTensorHandle3X2();
+
+// Return a matmul op multiplying `a` by `b`.
+TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
+
+// Return an 1-D INT32 tensor containing a single value 1.
+TFE_TensorHandle* TestAxisTensorHandle();
+
+// Return an op taking minimum of `input` long `axis` dimension.
+TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
+ TFE_TensorHandle* axis);
+
+// If there is a device of type `device_type`, returns true
+// and sets 'device_name' accordingly.
+// `device_type` must be either "GPU" or "TPU".
+bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
+ const char* device_type);
+
+#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 1833b25fea..734e712daa 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -48,7 +48,7 @@ struct OpTapeEntry {
// Should be called before deleting the backward function. TODO(apassos) use
// unique_ptrs to ensure this happens.
- std::function<void()> backward_function_deleter;
+ std::function<void(BackwardFunction*)> backward_function_deleter;
};
// Map from tensor_id to internally-defined operation-id of the operation which
@@ -110,12 +110,6 @@ class VSpace {
// Deletes the input tensor.
virtual void DeleteGradient(Gradient* gradient) const = 0;
-
- // Lets this VSpace know that it can release resources held by the
- // `backward_function`, It will not be called again.
- // `backward_function` must not be null.
- virtual void ReleaseBackwardFunction(
- BackwardFunction* backward_function) const = 0;
};
// Traces the execution of operations, doing eager garbage collection, and
@@ -130,7 +124,7 @@ class GradientTape {
GradientTape(bool persistent) : persistent_(persistent) {}
~GradientTape() {
for (const auto& pair : op_tape_) {
- pair.second.backward_function_deleter();
+ pair.second.backward_function_deleter(pair.second.backward_function);
}
}
@@ -139,12 +133,12 @@ class GradientTape {
void Watch(int64 tensor_id);
- void RecordOperation(const string& op_type,
- gtl::ArraySlice<TapeTensor> output_tensors,
- gtl::ArraySlice<int64> input_tensor_id,
- gtl::ArraySlice<tensorflow::DataType> input_dtypes,
- BackwardFunction* backward_function,
- const std::function<void()>& backward_function_deleter);
+ void RecordOperation(
+ const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+ gtl::ArraySlice<int64> input_tensor_id,
+ gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+ BackwardFunction* backward_function,
+ const std::function<void(BackwardFunction*)>& backward_function_deleter);
void DeleteTrace(int64 tensor_id);
@@ -218,9 +212,9 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation(
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
- const std::function<void()>& backward_function_deleter) {
+ const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
- backward_function_deleter();
+ backward_function_deleter(backward_function);
return;
}
std::vector<int64> ids;
@@ -275,7 +269,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
for (int64 id : op_it->second.input_tensor_id) {
DeleteTrace(id);
}
- op_it->second.backward_function_deleter();
+ op_it->second.backward_function_deleter(op_it->second.backward_function);
op_tape_.erase(op_it);
}
@@ -381,7 +375,8 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
// backward functions that will be used for gradient computation
// has been transferred to `result`.
for (const auto& op_pair : *op_tape) {
- op_pair.second.backward_function_deleter();
+ op_pair.second.backward_function_deleter(
+ op_pair.second.backward_function);
}
op_tape->clear();
}
@@ -473,7 +468,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
if (!persistent_) {
// Release all backprop functions
for (const auto& pair : state.op_tape) {
- pair.second.backward_function_deleter();
+ pair.second.backward_function_deleter(pair.second.backward_function);
}
}
};
@@ -541,7 +536,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
Status s = vspace.CallBackwardFunction(trace.backward_function,
out_gradients, &in_gradients);
if (!persistent_) {
- vspace.ReleaseBackwardFunction(trace.backward_function);
+ trace.backward_function_deleter(trace.backward_function);
}
if (!s.ok()) {
cleanup();
@@ -550,7 +545,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
} else {
in_gradients.resize(trace.input_tensor_id.size());
if (!persistent_) {
- vspace.ReleaseBackwardFunction(trace.backward_function);
+ trace.backward_function_deleter(trace.backward_function);
}
for (Gradient* grad : out_gradients) {
if (grad != nullptr) {
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 980e0eec9e..6d6c030a26 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -178,6 +178,7 @@ cc_library(
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
+ "//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:variable_ops",
],
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 27287e0f96..902fe27acd 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
XlaCompiler::Options options;
options.client = client;
- options.device_type = &cache->device_type();
+ options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index ab644ff5a6..b1943d3e1a 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -151,8 +151,7 @@ Status XlaCompileOnDemandOp::Compile(
core::ScopedUnref cache_ref(cache);
XlaCompiler::Options options;
- DeviceType device_type = metadata.jit_device_type();
- options.device_type = &device_type;
+ options.device_type = metadata.jit_device_type();
options.client = metadata.client();
options.flib_def =
new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index ea9e036604..43648402f6 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -50,11 +50,12 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
(void)registrations;
std::unique_ptr<XlaDevice> device;
- TF_RETURN_IF_ERROR(
- XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
- name_prefix, registration,
- /*transfer_as_literal=*/false,
- /*shape_representation_fn=*/{}, &device));
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
+ DEVICE_CPU_XLA_JIT, options, name_prefix,
+ registration,
+ /*transfer_as_literal=*/false,
+ /*shape_representation_fn=*/{},
+ /*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index f13b46c532..ed007d603e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -105,6 +106,25 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
return alloc_ptr;
}
+namespace {
+
+// Default PaddedShapeFn implementation that simply returns the unpadded
+// on-device shape. This is accurate for CPU and GPU devices that neither
+// transpose nor pad tensors.
+Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
+ const tensorflow::XlaTensor* xla_tensor =
+ tensorflow::XlaTensor::FromTensor(&tensor);
+ if (xla_tensor == nullptr) {
+ return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
+ }
+
+ const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
+ *shape = shaped_buffer.on_device_shape();
+ return Status::OK();
+}
+
+} // namespace
+
/* static */ Status XlaDevice::Create(
const string& platform_name, const string& device_name, int device_ordinal,
const string& jit_device_name, const SessionOptions& options,
@@ -112,7 +132,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
- std::unique_ptr<XlaDevice>* device) {
+ const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal;
@@ -133,17 +153,20 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
device->reset(new XlaDevice(
options, attrs, device_ordinal, DeviceType(jit_device_name),
- platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
+ platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
+ padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ PaddedShapeFn padded_shape_fn)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform),
- shape_representation_fn_(std::move(shape_representation_fn)) {}
+ shape_representation_fn_(std::move(shape_representation_fn)),
+ padded_shape_fn_(std::move(padded_shape_fn)) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@@ -178,10 +201,11 @@ XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
+ const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+ const PaddedShapeFn& padded_shape_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name,
- shape_representation_fn),
+ shape_representation_fn, padded_shape_fn),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index d5d345d43b..02e88ee679 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -45,13 +45,19 @@ namespace tensorflow {
class XlaDevice : public LocalDevice {
public:
+ // Given a tensor, sets `xla::Shape*` the shape of tensor's representation
+ // on device, fully padded. On error, the contents of `xla::Shape*`
+ // are undefined.
+ typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
+
// Wrapper class to store metadata about the XlaDevice, where it can be
// retrieved e.g., when lazily creating the XlaCompilationCache device.
class Metadata {
public:
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ PaddedShapeFn padded_shape_fn);
// The index of the device on this host.
int device_ordinal() const;
@@ -62,12 +68,14 @@ class XlaDevice : public LocalDevice {
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
return shape_representation_fn_;
}
+ const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }
private:
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+ PaddedShapeFn padded_shape_fn_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@@ -81,6 +89,8 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
+ // If padded_shape_fn is empty, a default implementation that returns
+ // the on-host shape is used.
static Status Create(
const string& platform_name, const string& device_name,
int device_ordinal, const string& jit_device_name,
@@ -88,12 +98,16 @@ class XlaDevice : public LocalDevice {
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
- std::unique_ptr<XlaDevice>* device);
+ const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);
+ // Creates a new XLA Device.
+ // If padded_shape_fn is empty, a default implementation that returns
+ // the logical on-device shape without padding is used.
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
+ const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+ const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
@@ -110,6 +124,7 @@ class XlaDevice : public LocalDevice {
Tensor* tensor) override;
xla::LocalClient* client() const;
+ const Metadata& metadata() { return xla_metadata_; }
xla::StatusOr<se::Stream*> GetStream();
// If not already set, create and set GpuDeviceInfo.
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index ff30b62bad..71e63b110b 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -54,16 +54,26 @@ XlaTransferManager::XlaTransferManager(
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(std::move(shape_representation_fn)) {}
+ shape_representation_fn_(std::move(shape_representation_fn)) {
+ if (!shape_representation_fn_) {
+ shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) {
+ return shape;
+ };
+ }
+}
Status XlaTransferManager::TransferLiteralToDevice(
const Tensor& host_tensor, Tensor* device_tensor) const {
- xla::Literal literal;
- TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal));
- VLOG(1) << "Transfer to device as literal: " << literal.ToString();
+ xla::Shape xla_shape;
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
+ host_tensor.shape(), &xla_shape));
+ xla::BorrowingLiteral literal(
+ static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(device_tensor)->shaped_buffer();
+ VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
+ << shaped_buffer.ToString();
return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal,
shaped_buffer);
}
@@ -76,7 +86,8 @@ Status XlaTransferManager::TransferLiteralFromDevice(
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
transfer_manager_->TransferLiteralFromDevice(
stream_->parent(), shaped_buffer));
- VLOG(1) << "Transfer from device as literal: " << literal->ToString();
+ VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
+ << shaped_buffer.ToString();
Tensor tensor;
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
@@ -98,7 +109,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
<< " "
<< reinterpret_cast<const void*>(
device_tensor->tensor_data().data())
- << " " << cpu_tensor->NumElements();
+ << " " << cpu_tensor->NumElements() << " "
+ << cpu_tensor->shape().DebugString() << " "
+ << device_tensor->shape().DebugString();
void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
const int64 total_bytes = cpu_tensor->TotalBytes();
@@ -106,13 +119,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
- TensorShape shape;
- if (shape_representation_fn_) {
- shape = shape_representation_fn_(device_tensor->shape(),
- device_tensor->dtype());
- } else {
- shape = device_tensor->shape();
- }
+ TensorShape shape = shape_representation_fn_(device_tensor->shape(),
+ device_tensor->dtype());
if (!xla_tensor->has_shaped_buffer()) {
Status s = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), shape, client_,
@@ -165,7 +173,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
device_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
- << device_tensor->NumElements();
+ << " " << device_tensor->NumElements() << " "
+ << cpu_tensor->shape().DebugString() << " "
+ << device_tensor->shape().DebugString();
const int64 total_bytes = cpu_tensor->TotalBytes();
se::DeviceMemoryBase dev_src_ptr =
@@ -194,6 +204,42 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
done(Status::OK());
}
+void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
+ Tensor* dst_tensor,
+ const StatusCallback& done) {
+ // TODO(phawkins): replace this code with an asynchronous implementation.
+ auto body = [&]() {
+ if (src_tensor.NumElements() == 0) {
+ return Status::OK();
+ }
+ XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
+ XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
+ CHECK(xla_src && xla_dst)
+ << "Missing destination tensor for device-to-device copy";
+ if (!xla_dst->has_shaped_buffer()) {
+ TensorShape shape =
+ shape_representation_fn_(src_tensor.shape(), src_tensor.dtype());
+ TF_RETURN_IF_ERROR(
+ xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
+ stream_->parent()->device_ordinal()));
+ }
+ TF_RETURN_IF_ERROR(
+ xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
+ [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+ const se::DeviceMemoryBase& from_buffer =
+ xla_src->shaped_buffer().buffers().element(index);
+ CHECK_EQ(buffer->size(), from_buffer.size());
+ if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer,
+ buffer->size())) {
+ return errors::Internal("Device to device memcpy failed");
+ }
+ return Status::OK();
+ }));
+ return Status::OK();
+ };
+ done(body());
+}
+
XlaDeviceContext::XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
@@ -215,4 +261,10 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
done);
}
+void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor,
+ Tensor* dst_tensor,
+ const StatusCallback& done) {
+ manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 9af9655868..ee346e5653 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -55,6 +55,10 @@ class XlaTransferManager {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done);
+
+ void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
+ const StatusCallback& done);
+
se::Stream* stream() const { return stream_; }
private:
@@ -72,7 +76,7 @@ class XlaTransferManager {
xla::TransferManager* transfer_manager_;
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
- const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@@ -90,6 +94,9 @@ class XlaDeviceContext : public DeviceContext {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
+ void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
+ const StatusCallback& done);
+
se::Stream* stream() const override { return manager_.stream(); }
private:
diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc
index f68dba6b6a..5ecb1afa7b 100644
--- a/tensorflow/compiler/jit/xla_device_ops.cc
+++ b/tensorflow/compiler/jit/xla_device_ops.cc
@@ -15,7 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_ops.h"
+#include <memory>
+
#include "tensorflow/compiler/jit/xla_device_context.h"
+#include "tensorflow/compiler/jit/xla_tensor.h"
namespace tensorflow {
@@ -26,4 +29,82 @@ void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) {
<< type_string() << " on an XLA device. This should never happen.";
}
+XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c)
+ : AsyncOpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+}
+
+void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context,
+ DoneCallback done) {
+ OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(),
+ errors::InvalidArgument(
+ "Variable and value dtypes don't match; respectively, ",
+ dtype_, " and ", context->input(1).dtype()),
+ done);
+ Var* variable = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ LookupOrCreateResource<Var>(
+ context, HandleFromInput(context, 0), &variable,
+ [this, context](Var** ptr) {
+ *ptr = new Var(dtype_);
+ PersistentTensor unused;
+ Tensor* tmp;
+ AllocatorAttributes attr;
+ TF_RETURN_IF_ERROR(context->allocate_persistent(
+ dtype_, context->input(1).shape(), &unused, &tmp, attr));
+ *(*ptr)->tensor() = *tmp;
+ return Status::OK();
+ }),
+ done);
+ core::ScopedUnref s(variable);
+
+ OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_,
+ errors::InvalidArgument(
+ "Trying to assign variable with wrong dtype. Expected ",
+ DataTypeString(variable->tensor()->dtype()), " got ",
+ DataTypeString(dtype_)),
+ done);
+
+ const Tensor& value = context->input(1);
+ AllocatorAttributes attr;
+
+ // Copying is unnecessary if we are the last user of the value tensor, we can
+ // just adopt the input tensor's buffer instead.
+ std::unique_ptr<Tensor> input_alias = context->forward_input(
+ 1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_,
+ value.shape(), DEVICE_MEMORY, attr);
+ mutex_lock ml(*variable->mu());
+ variable->is_initialized = true;
+ if (input_alias) {
+ *variable->tensor() = *input_alias;
+ done();
+ return;
+ }
+
+ // Need to copy, but maybe we can re-use variable's buffer?
+ if (!XlaTensor::RefCountIsOne(*variable->tensor()) ||
+ !variable->tensor()->shape().IsSameSize(value.shape())) {
+ // Copy to new buffer
+ PersistentTensor unused;
+ Tensor* tmp;
+ OP_REQUIRES_OK_ASYNC(context,
+ context->allocate_persistent(dtype_, value.shape(),
+ &unused, &tmp, attr),
+ done);
+ *variable->tensor() = *tmp;
+ }
+
+ XlaDeviceContext* device_context =
+ static_cast<XlaDeviceContext*>(context->op_device_context());
+
+ variable->Ref();
+ device_context->CopyDeviceTensorToDevice(
+ value, variable->tensor(), [context, variable, done](Status status) {
+ variable->Unref();
+ context->SetStatus(status);
+ done();
+ });
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 9c00a0682c..b27c32e9bc 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/variable_ops.h"
@@ -41,6 +42,15 @@ class XlaDeviceDummyOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
};
+class XlaAssignVariableOp : public AsyncOpKernel {
+ public:
+ explicit XlaAssignVariableOp(OpKernelConstruction* c);
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
+
+ private:
+ DataType dtype_;
+};
+
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \
.Device(DEVICE) \
@@ -73,7 +83,19 @@ class XlaDeviceDummyOp : public OpKernel {
\
REGISTER_KERNEL_BUILDER( \
Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \
- ResourceHandleOp<Var>);
+ ResourceHandleOp<Var>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
+ ReadVariableOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \
+ XlaAssignVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
+ ControlTriggerOp); \
+ REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
+ SwitchOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 26842fbe5c..c0d86a28c7 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -49,7 +49,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
- /*shape_representation_fn=*/{}, &device);
+ /*shape_representation_fn=*/{},
+ /*padded_shape_fn=*/{}, &device);
if (!status.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << status;
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 4146996f63..661187f4a8 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -48,11 +48,12 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
registration.compile_resource_ops = true;
std::unique_ptr<XlaDevice> device;
- TF_RETURN_IF_ERROR(XlaDevice::Create(
- "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT,
- options, name_prefix, registration,
- /*transfer_as_literal=*/false,
- /*shape_representation_fn=*/{}, &device));
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0,
+ DEVICE_INTERPRETER_XLA_JIT, options,
+ name_prefix, registration,
+ /*transfer_as_literal=*/false,
+ /*shape_representation_fn=*/{},
+ /*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index a7211c9c7e..3c44c4ae6d 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace tensorflow {
-/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) {
+/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
if (tensor->NumElements() == 0) {
return nullptr;
}
@@ -27,8 +27,8 @@ namespace tensorflow {
return xla_tensor;
}
-/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
- return FromTensor(const_cast<Tensor*>(tensor));
+/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
+ return tensor.RefCountIsOne();
}
/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
@@ -67,6 +67,8 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
index_to_buffer.second = buffer.Forget();
}
+ VLOG(4) << shaped_buffer.ToString();
+
set_shaped_buffer(std::move(shaped_buffer));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 6b29c82ec1..c54001a999 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -34,10 +34,9 @@ class XlaTensor {
public:
// Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast
// fails.
- static XlaTensor* FromTensor(Tensor* tensor);
- // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast
- // fails.
- static const XlaTensor* FromTensor(const Tensor* tensor);
+ static XlaTensor* FromTensor(const Tensor* tensor);
+
+ static bool RefCountIsOne(const Tensor& tensor);
// Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
// which case the returned value is shaped_buffer()->root_buffer(), or a
@@ -62,6 +61,10 @@ class XlaTensor {
CHECK(has_shaped_buffer());
return *shaped_buffer_;
}
+ xla::ShapedBuffer& shaped_buffer() {
+ CHECK(has_shaped_buffer());
+ return *shaped_buffer_;
+ }
// Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index 0a0d335ca7..03d96a2cd8 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -153,7 +153,7 @@ class DepthwiseConv2DTest(XLATestCase):
dtype=data_type).reshape(filter_in_sizes)
with self.test_session() as sess:
if data_type == np.float32:
- tolerance = 1e-5
+ tolerance = 1e-4
else:
self.assertEqual(data_type, np.float64)
tolerance = 1e-8
@@ -339,7 +339,7 @@ class DepthwiseConv2DTest(XLATestCase):
gpu_value = _GetVal(use_xla=True)
cpu_value = _GetVal(use_xla=False)
- self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-3)
def testDepthwiseConv2DInputGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 52d8d6d295..4dff5f0f40 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -117,6 +117,15 @@ class EagerTest(XLATestCase):
v.assign_add(2.0)
self.assertEqual(3.0, v.numpy())
+ def testReadAssignRead(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ val1 = v.read_value()
+ v.assign_add(2.0)
+ val2 = v.read_value()
+ self.assertEqual(1.0, val1.numpy())
+ self.assertEqual(3.0, val2.numpy())
+
def testGradient(self):
def f(x):
return x
@@ -136,6 +145,21 @@ class EagerTest(XLATestCase):
grads = backprop.implicit_grad(f)()
self.assertEqual(2., grads[0][0].numpy())
+ def testMultipleVariableReads(self):
+ # This test makes sure consecutive variable reads don't copy
+ # the underlying memory.
+ with self.test_scope():
+ # Create 128MiB variables
+ var = resource_variable_ops.ResourceVariable(
+ array_ops.ones([32, 1024, 1024]))
+
+ # Read the same variable 100 times. If the underlying tensor
+ # is not copied, this is a trivial operation. If it is copied,
+ # this will eat over 13GB and OOM.
+ values = []
+ for _ in range(100):
+ values.append(var.value())
+
class EagerFunctionTest(XLATestCase):
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 4b0043b6b4..6e0db54b7a 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -125,7 +125,7 @@ class JitLaunchTest(test.TestCase):
for (x, y) in zip(compiled, direct):
self.assertAllClose(x, y, rtol=1e-1)
else:
- self.assertAllClose(compiled, direct)
+ self.assertAllClose(compiled, direct, rtol=1e-2)
def testNoOutputs(self):
with session_lib.Session() as sess:
diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py
index 8ecad00f6e..2c09b03d5a 100644
--- a/tensorflow/compiler/tests/variable_ops_test.py
+++ b/tensorflow/compiler/tests/variable_ops_test.py
@@ -187,6 +187,25 @@ class VariableOpsTest(XLATestCase):
rtol=1e-4)
self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)
+ def testWriteOfAliasedTensor(self):
+ for dtype in self.numeric_types:
+ init = np.array([[1, 2j], [3, 4]]).astype(dtype)
+ update = np.array([[7, 1j], [2, 11]]).astype(dtype)
+ with self.test_session() as sess, self.test_scope():
+ v = resource_variable_ops.ResourceVariable(init)
+ sess.run(variables.variables_initializer([v]))
+ p = array_ops.placeholder(dtype)
+ q = array_ops.identity(p)
+ x = v.read_value()
+ # Writes the value of 'p' to 'v', but keeps a reference to the original
+ # value of 'v' so the variable update cannot reuse its buffer.
+ with ops.control_dependencies([x]):
+ y = v.assign(q)
+ result = sess.run([x, y, q], {p: update})
+ self.assertAllClose(init, result[0])
+ self.assertAllClose(update, result[1])
+ self.assertAllClose(update, result[2])
+
class StridedSliceAssignChecker(object):
"""Compares the results of a slice assignment using Tensorflow and numpy."""
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index b707bd0963..f0b010fa67 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.platform import test
@@ -46,6 +47,12 @@ class XlaDeviceTest(XLATestCase):
result = sess.run(z, {x: inputs})
self.assertAllCloseAccordingToType(result, inputs + inputs)
+ def testControlTrigger(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = gen_control_flow_ops.control_trigger()
+ sess.run(x)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc
index 8c8a9bbe78..65ab9da8d7 100644
--- a/tensorflow/compiler/tf2xla/kernels/no_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc
@@ -24,8 +24,7 @@ namespace tensorflow {
REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp);
// We register ControlTrigger as a no-op. This is correct since nodes seen
-// by the XLA compiler are never dead. This may need rethinking when we add
-// support for conditionals to XLA.
-REGISTER_XLA_OP(Name("ControlTrigger"), NoOp);
+// by the XLA compiler are never dead.
+REGISTER_XLA_OP(Name("ControlTrigger").CompilationOnly(), NoOp);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index 6109db8e89..a163fa0a5b 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -57,7 +57,7 @@ class ReadVariableOp : public XlaOpKernel {
private:
DataType dtype_;
};
-REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp);
+REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp);
class AssignVariableOp : public XlaOpKernel {
public:
@@ -67,7 +67,7 @@ class AssignVariableOp : public XlaOpKernel {
ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1)));
}
};
-REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp);
+REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp);
class AssignAddVariableOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 3a08aa8cf4..ac768b206e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
// Compile the graph into an XLA computation.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
- DeviceType device_type(DEVICE_CPU_XLA_JIT);
- compiler_options.device_type = &device_type;
+ compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
compiler_options.flib_def = &graph->flib_def();
compiler_options.graph_def_version = graph->versions().producer();
compiler_options.allow_cpu_custom_calls = true;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index f7098917b1..ccbc74eb31 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -83,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(options),
initialization_status_(Status::OK()),
next_step_id_(1),
- device_(
- new XlaCompilationDevice(SessionOptions(), *options_.device_type)),
+ device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) {
- // We no longer need the device_type.
- options_.device_type = nullptr;
-
+ CHECK(!options_.device_type.type_string().empty());
if (options_.populate_resource_manager) {
initialization_status_ =
(*options_.populate_resource_manager)(device_->resource_manager());
@@ -659,6 +656,59 @@ Status XlaCompiler::CompileSingleOp(
return CompileGraph(options, name, std::move(graph), args, result);
}
+namespace {
+
+// Check that the ops of all non-functional nodes have been registered.
+string ValidateFunctionDef(const FunctionDef* fdef,
+ const FunctionLibraryDefinition& flib_def) {
+ std::vector<string> invalid_ops;
+ for (const NodeDef& node : fdef->node_def()) {
+ const string& op = node.op();
+ if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
+ continue;
+ }
+ const OpDef* op_def;
+ if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
+ invalid_ops.push_back(op);
+ }
+ }
+ return tensorflow::str_util::Join(invalid_ops, ", ");
+}
+
+// Check that the graph doesn't have any nodes incompatible with given
+// device_type.
+Status ValidateGraph(const Graph* graph,
+ const FunctionLibraryDefinition& flib_def,
+ const DeviceType& device_type, const string& name) {
+ std::vector<string> invalid_ops;
+ for (const Node* node : graph->nodes()) {
+ if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
+ continue;
+ }
+ const FunctionDef* fdef = flib_def.Find(node->def().op());
+ if (fdef) {
+ string error_msg = ValidateFunctionDef(fdef, flib_def);
+ if (!error_msg.empty()) {
+ invalid_ops.push_back(
+ strings::StrCat(node->def().op(), ":{", error_msg, "}"));
+ }
+ continue;
+ }
+ if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
+ invalid_ops.push_back(node->def().op());
+ }
+ }
+ if (!invalid_ops.empty()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Detected unsupported operations when trying to compile graph ", name,
+ " on ", device_type.type_string(), ":",
+ tensorflow::str_util::Join(invalid_ops, ", ")));
+ }
+ return Status::OK();
+}
+
+} // namespace
+
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
string const& name,
std::unique_ptr<Graph> graph,
@@ -681,6 +731,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
graph.get(), local_flib_def_.get()));
+ // Detect ops incompatible with the device_type.
+ // FunctionalizeControlFlow may remove some unsupported ops.
+ TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
+ options_.device_type, name));
+
xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext(
this, &builder, options_.allow_cpu_custom_calls,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index bf496bd8bc..76f4c4c1ea 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -244,9 +245,9 @@ class XlaCompiler {
typedef std::function<TensorShape(const TensorShape&, DataType)>
ShapeRepresentationFn;
struct Options {
- // Name of the compilation device to use. Needs to be live only during
- // XlaCompiler's constructor.
- const DeviceType* device_type = nullptr;
+ // Name of the compilation device to use. It must be set by the caller.
+ // The default empty value is invalid.
+ DeviceType device_type = DeviceType("");
xla::Client* client = nullptr;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 55772ca324..246b386f38 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -45,8 +45,6 @@ namespace tensorflow {
class XlaCompilerTest : public ::testing::Test {
protected:
- XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
-
void SetUp() override {
client_ = xla::ClientLibrary::LocalClientOrDie();
@@ -58,7 +56,7 @@ class XlaCompilerTest : public ::testing::Test {
XlaCompiler::Options DefaultOptions() {
XlaCompiler::Options options;
- options.device_type = &cpu_device_type_;
+ options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
options.client = client_;
options.flib_def = flib_def_.get();
return options;
@@ -68,7 +66,6 @@ class XlaCompilerTest : public ::testing::Test {
return compiler->local_flib_def_.get();
}
- DeviceType cpu_device_type_;
xla::Client* client_;
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
};
@@ -979,5 +976,54 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
+// Tests a graph which has a function with an invalid op.
+TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
+ XlaCompiler compiler(DefaultOptions());
+
+ FunctionDefLibrary flib;
+ FunctionDef fn = FillFn();
+ NodeDef* node = fn.add_node_def();
+ node->set_name("Invalid");
+ node->set_op("InvalidOp"); /* unsupported op */
+ node = fn.add_node_def();
+ node->set_name("Switch");
+ node->set_op("Switch"); /* control flow node */
+ *flib.add_function() = fn;
+
+ TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
+ auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
+ TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+ NodeDef def;
+ TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
+ .Input(value.name(), 0, DT_INT32)
+ .Input(shape.name(), 1, DT_INT32)
+ .Finalize(&def));
+ Status status;
+ Node* fill = scope.graph()->AddNode(def, &status);
+ TF_ASSERT_OK(status);
+ TF_ASSERT_OK(scope.DoShapeInference(fill));
+ scope.graph()->AddEdge(value.node(), 0, fill, 0);
+ scope.graph()->AddEdge(shape.node(), 0, fill, 1);
+
+ auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
+
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ std::vector<XlaCompiler::Argument> args;
+ XlaCompiler::CompilationResult result;
+ status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
+ std::move(graph), args, &result);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
+ << status.error_message();
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index c6deb959a5..c08db7e3fb 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -500,6 +500,37 @@ cc_library(
)
cc_library(
+ name = "scanner",
+ srcs = ["scanner.cc"],
+ hdrs = ["scanner.h"],
+ visibility = [":internal"],
+ deps = [
+ ":status",
+ ":status_macros",
+ ":types",
+ ":util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_cc_test(
+ name = "scanner_test",
+ srcs = ["scanner_test.cc"],
+ deps = [
+ ":scanner",
+ ":status",
+ ":status_macros",
+ ":test",
+ ":types",
+ ":util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "text_literal_reader",
srcs = ["text_literal_reader.cc"],
hdrs = ["text_literal_reader.h"],
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index c9d275a77b..3d596a6e65 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -64,7 +64,7 @@ StatusOr<std::unique_ptr<Literal>> Client::Transfer(
}
StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
- const Literal& literal, const DeviceHandle* device_handle) {
+ const LiteralSlice& literal, const DeviceHandle* device_handle) {
TransferToServerRequest request;
*request.mutable_literal() = literal.ToProto();
if (device_handle) {
@@ -91,7 +91,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
return MakeUnique<GlobalData>(stub_, response.data());
}
-Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
+Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id,
const DeviceHandle* device_handle) {
TransferToInfeedRequest request;
*request.mutable_literal() = literal.ToProto();
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index d57e2536d0..cda8a71f71 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -107,14 +107,14 @@ class Client {
// device (and its replicas if replication is enabled). Otherwise, data is
// transferred to the default device (and its replicas).
StatusOr<std::unique_ptr<GlobalData>> TransferToServer(
- const Literal& literal, const DeviceHandle* device_handle = nullptr);
+ const LiteralSlice& literal, const DeviceHandle* device_handle = nullptr);
// Transfer the given literal to the Infeed interface of the device.
//
// device_handle and replica_id together specify a particular device; a device
// assigned for the given replica_id among the replicas that the given device
// handle belongs to.
- Status TransferToInfeed(const Literal& literal, int64 replica_id = 0,
+ Status TransferToInfeed(const LiteralSlice& literal, int64 replica_id = 0,
const DeviceHandle* device_handle = nullptr);
// Transfers from the Outfeed of the device.
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index d63d4ec7f3..3f23e52fc2 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -58,12 +58,18 @@ class LocalExecutable {
// Validates that the given arguments and options satisfy various constraints
// of the computation.
+ //
+ // The given ExecutableRunOptions override any values from legacy_flags
+ // (TF_XLA_FLAGS environment variable).
Status ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& run_options, const Backend& backend);
// Records the computation in a SessionModule proto with the arguments used to
// invoke it, and the result. Enabled by flag: --tla_dump_executions_to.
+ //
+ // The given ServiceExecutableRunOptions override any values from legacy_flags
+ // (TF_XLA_FLAGS environment variable).
StatusOr<ScopedShapedBuffer> ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
@@ -109,6 +115,9 @@ class LocalClient : public Client {
// Build and return a LocalExecutable object. The executable is compiled using
// the given XlaComputation, argument layouts and options.
+ //
+ // The given ExecutableBuildOptions override any values from legacy_flags
+ // (TF_XLA_FLAGS environment variable).
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
const XlaComputation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index a76fdcda25..89cafa1a7d 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -65,6 +65,16 @@ void SetDefaultLayoutToContainer(
return layout;
}
+/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
+ tensorflow::gtl::ArraySlice<int64> major_to_minor) {
+ Layout layout;
+ layout.set_format(DENSE);
+ for (int i = major_to_minor.size() - 1; i >= 0; i--) {
+ layout.add_minor_to_major(major_to_minor[i]);
+ }
+ return layout;
+}
+
/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
Layout layout;
layout.set_format(SPARSE);
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index d3d6a2cc94..739bbe7367 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -36,6 +36,10 @@ class LayoutUtil {
// convenience function for protobuf construction.)
static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
+ // Similar to MakeLayout, but take indices in reverse order.
+ static Layout MakeLayoutFromMajorToMinor(
+ tensorflow::gtl::ArraySlice<int64> major_to_minor);
+
// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
static Layout MakeSparseLayout(int64 max_sparse_elements);
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index 3696fdbe12..a588f4a03d 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -317,7 +317,15 @@ class NearComparator {
rel_error = std::numeric_limits<float>::infinity();
} else {
abs_error = FpAbsoluteValue(actual - expected);
- rel_error = abs_error / FpAbsoluteValue(expected);
+ // If the expected result is exactly zero, don't compute relative error;
+ // that's meaningless.
+ //
+ // TODO(b/80321728): Come up with a better way to handle this case.
+ if (expected == NativeT{}) {
+ rel_error = 0;
+ } else {
+ rel_error = abs_error / FpAbsoluteValue(expected);
+ }
}
const bool is_abs_mismatch = abs_error > error_.abs;
const bool is_rel_mismatch = rel_error > error_.rel;
@@ -716,9 +724,11 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
}
return AppendStatus(result,
- tensorflow::strings::Printf("expected: %s\nactual: %s",
- expected.ToString().c_str(),
- actual.ToString().c_str()));
+ tensorflow::strings::Printf(
+ "\nat index: %s\nexpected: %s\nactual: %s",
+ Literal::MultiIndexAsString(multi_index).c_str(),
+ ToStringTruncated(expected).c_str(),
+ ToStringTruncated(actual).c_str()));
}
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 4c560767dc..7563cc1e34 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -807,6 +807,47 @@ std::unique_ptr<Literal> LiteralBase::Relayout(
return result;
}
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions) const {
+ if (!ShapeUtil::IsArray(shape())) {
+ return InvalidArgument("Broadcast only supports arrays.");
+ }
+
+ for (int64 i = 0; i < dimensions.size(); i++) {
+ TF_RET_CHECK(shape().dimensions(i) ==
+ result_shape.dimensions(dimensions[i]));
+ }
+
+ std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
+
+ // scratch_source_index is temporary storage space for the computed index into
+ // the input literal. We put it here to avoid allocating an std::vector in
+ // every iteration of ShapeUtil::ForEachIndex.
+ std::vector<int64> scratch_source_index(shape().dimensions_size());
+
+ char* dest_data = static_cast<char*>(result->untyped_data());
+ const char* source_data = static_cast<const char*>(untyped_data());
+ const int64 primitive_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+ ShapeUtil::ForEachIndex(
+ result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ for (int64 i = 0; i < dimensions.size(); ++i) {
+ scratch_source_index[i] = output_index[dimensions[i]];
+ }
+ int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ result_shape, output_index);
+ int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ shape(), scratch_source_index);
+ memcpy(dest_data + primitive_size * dest_index,
+ source_data + primitive_size * source_index, primitive_size);
+ return true;
+ });
+
+ return std::move(result);
+}
+
StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 609dc7a3ac..2ca9060cc7 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -277,6 +277,12 @@ class LiteralBase {
StatusOr<std::unique_ptr<Literal>> Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const;
+ // Creates a new literal by broadcasting this literal with `dimensions` to
+ // yield a literal of shape `result_shape`.
+ StatusOr<std::unique_ptr<Literal>> Broadcast(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
// in the original literal, and it specifies the order of the new dimensions
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 77f979a0d7..f127cee0fd 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -1810,5 +1810,35 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
+TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
+ std::unique_ptr<Literal> literal = Literal::CreateR1<int64>({1, 2});
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> broadcasted_literal,
+ literal->Broadcast(
+ /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{0}));
+ EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int64>({{1, 1}, {2, 2}}));
+}
+
+TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
+ std::unique_ptr<Literal> literal = Literal::CreateR1<int64>({1, 2});
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> broadcasted_literal,
+ literal->Broadcast(
+ /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+ /*dimensions=*/{1}));
+ EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int64>({{1, 2}, {1, 2}}));
+}
+
+TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
+ std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(9);
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> broadcasted_literal,
+ literal->Broadcast(
+ /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+ /*dimensions=*/{}));
+ EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int32>({{9, 9}, {9, 9}}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 2698ba7d79..8fa6961d19 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -265,9 +265,9 @@ class ReferenceUtil {
const Array3D<T>& rhs,
int concatenate_dimension) {
CHECK(0 <= concatenate_dimension && concatenate_dimension < 3);
- std::vector<int64> lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()};
- std::vector<int64> rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()};
- std::vector<int64> out_dims = {rhs.n1(), rhs.n2(), rhs.n3()};
+ const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()};
+ const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
+ int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
for (int i = 0; i < 3; ++i) {
if (i != concatenate_dimension) {
out_dims[i] = lhs_dims[i];
@@ -299,9 +299,9 @@ class ReferenceUtil {
const Array4D<T>& rhs,
int concatenate_dimension) {
CHECK(0 <= concatenate_dimension && concatenate_dimension < 4);
- std::vector<int64> lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()};
- std::vector<int64> rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
- std::vector<int64> out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
+ const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()};
+ const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
+ int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
for (int i = 0; i < 4; ++i) {
if (i != concatenate_dimension) {
out_dims[i] = lhs_dims[i];
@@ -553,12 +553,11 @@ class ReferenceUtil {
const NativeT pad) {
CHECK_EQ(padding.dimensions_size(), 3);
- const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
- operand.n3()};
- std::vector<int64> pad_low(3);
- std::vector<int64> pad_high(3);
- std::vector<int64> pad_interior(3);
- std::vector<int64> output_bounds(3);
+ const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()};
+ int64 pad_low[3];
+ int64 pad_high[3];
+ int64 pad_interior[3];
+ int64 output_bounds[3];
for (int64 i = 0; i < 3; ++i) {
pad_low[i] = padding.dimensions(i).edge_padding_low();
pad_high[i] = padding.dimensions(i).edge_padding_high();
@@ -574,7 +573,7 @@ class ReferenceUtil {
Array3D<NativeT> result(output_bounds[0], output_bounds[1],
output_bounds[2]);
- std::vector<int> indices = {0, 0, 0};
+ int indices[] = {0, 0, 0};
for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
@@ -612,12 +611,12 @@ class ReferenceUtil {
const NativeT pad) {
CHECK_EQ(padding.dimensions_size(), 4);
- const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
- operand.n3(), operand.n4()};
- std::vector<int64> pad_low(4);
- std::vector<int64> pad_high(4);
- std::vector<int64> pad_interior(4);
- std::vector<int64> output_bounds(4);
+ const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(),
+ operand.n4()};
+ int64 pad_low[4];
+ int64 pad_high[4];
+ int64 pad_interior[4];
+ int64 output_bounds[4];
for (int64 i = 0; i < 4; ++i) {
pad_low[i] = padding.dimensions(i).edge_padding_low();
pad_high[i] = padding.dimensions(i).edge_padding_high();
diff --git a/tensorflow/compiler/xla/scanner.cc b/tensorflow/compiler/xla/scanner.cc
new file mode 100644
index 0000000000..f23a1417fc
--- /dev/null
+++ b/tensorflow/compiler/xla/scanner.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/scanner.h"
+
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+namespace {
+
+// Returns true if c can be the first character in an identifier.
+bool IsIdentifierFirst(int c) { return std::isalpha(c) || c == '_'; }
+
+// Returns true if c can be the non-first character in an identifier.
+bool IsIdentifierLater(int c) { return std::isalnum(c) || c == '_'; }
+
+// Returns true if str is an identifier.
+bool IsIdentifier(tensorflow::StringPiece str) {
+ if (str.empty() || !IsIdentifierFirst(str[0])) {
+ return false;
+ }
+ for (int64 i = 1; i < str.size(); ++i) {
+ if (!IsIdentifierLater(str[i])) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+Scanner::Scanner(tensorflow::StringPiece input) : input_(input), position_(0) {}
+
+bool Scanner::ok() const { return status().ok(); }
+
+const Status& Scanner::status() const { return status_; }
+
+bool Scanner::Match(tensorflow::StringPiece match) {
+ SkipWhitespace();
+ if (ok() && position_ + match.size() <= input_.size() &&
+ std::equal(match.begin(), match.end(), input_.begin() + position_)) {
+ SkipChars(match.size());
+
+ VLOG(10) << "Matched \"" << match << "\"";
+ return true;
+ } else {
+ return false;
+ }
+}
+
+void Scanner::Expect(tensorflow::StringPiece expect) {
+ if (!Match(expect)) {
+ SetError(tensorflow::strings::StrCat("Expected \"", expect, "\"."));
+ }
+}
+
+bool Scanner::MatchReadIdentifier(string* identifier) {
+ SkipWhitespace();
+ if (!IsIdentifierFirst(PeekChar())) {
+ return false;
+ }
+ identifier->clear();
+ do {
+ *identifier += ReadChar();
+ } while (IsIdentifierLater(PeekChar()));
+
+ VLOG(10) << "Read identifier " << identifier;
+ CHECK(IsIdentifier(*identifier));
+ return true;
+}
+
+string Scanner::ReadIdentifier() {
+ string identifier;
+ if (!MatchReadIdentifier(&identifier)) {
+ SetError("Expected identifier.");
+ }
+ return identifier;
+}
+
+void Scanner::ExpectIdentifier(tensorflow::StringPiece expect) {
+ CHECK(IsIdentifier(expect));
+
+ string identifier;
+ if (!MatchReadIdentifier(&identifier)) {
+ SetError(tensorflow::strings::StrCat("Expected identifier ", expect, "."));
+ }
+ if (identifier != expect) {
+ SetError(tensorflow::strings::StrCat("Expected identifier ", expect,
+ ", but got ", identifier, "."));
+ }
+}
+
+// Matches the end of the input, also known as End Of File (EOF).
+bool Scanner::MatchEof() {
+ SkipWhitespace();
+ return PeekChar() == EOF;
+}
+
+void Scanner::ExpectEof() {
+ if (!MatchEof()) {
+ SetError("Expected end of input.");
+ }
+}
+
+// Reads a vector of the format "(1, 2, 3)".
+std::vector<int64> Scanner::ReadIntVector() {
+ std::vector<int64> ints;
+ Expect("(");
+ if (!Match(")") && ok()) {
+ ints.push_back(ReadInt());
+ while (Match(",")) {
+ ints.push_back(ReadInt());
+ }
+ Expect(")");
+ }
+
+ VLOG(10) << "Read int vector with " << ints.size() << " elements.";
+ return ints;
+}
+
+int64 Scanner::ReadInt() {
+ bool negative = Match("-");
+ if (!PeekDigit()) {
+ SetError("Expected integer.");
+ return 0;
+ }
+
+ int64 integer = 0;
+ do {
+ integer = (ReadChar() - '0') + integer * 10;
+ } while (PeekDigit());
+ integer = negative ? -integer : integer;
+
+ VLOG(10) << "Read integer " << integer;
+ return integer;
+}
+
+void Scanner::SkipWhitespace() {
+ while (PeekWhitespace()) {
+ SkipChars(1);
+ }
+}
+
+int Scanner::ReadChar() {
+ int c = PeekChar();
+ SkipChars(1);
+
+ VLOG(20) << "Read char " << c;
+ return c;
+}
+
+int Scanner::PeekChar() const {
+ return ok() && position_ < input_.size() ? input_[position_] : EOF;
+}
+
+bool Scanner::PeekDigit() const {
+ // Do not use std::isdigit since it depends on the locale and we do not
+ // handle any digits beyond 0-9.
+ const char c = PeekChar();
+ return '0' <= c && c <= '9';
+}
+
+bool Scanner::PeekAlnum() const { return std::isalnum(PeekChar()); }
+
+bool Scanner::PeekWhitespace() const { return std::isspace(PeekChar()); }
+
+void Scanner::SkipChars(int64 count) {
+ CHECK_GE(count, 0);
+ position_ += count;
+}
+
+void Scanner::SetError(string error_message) {
+ // Only the first error is recorded since any later errors will likely be a
+ // consequence of the first error.
+ if (ok()) {
+ status_ = InvalidArgumentStrCat(std::move(error_message));
+ position_ = input_.size();
+ VLOG(10) << "Failed scanner with error " << status_.ToString();
+ } else {
+ VLOG(10) << "Error on already failed scanner is " << error_message;
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/scanner.h b/tensorflow/compiler/xla/scanner.h
new file mode 100644
index 0000000000..86b04ae7f9
--- /dev/null
+++ b/tensorflow/compiler/xla/scanner.h
@@ -0,0 +1,102 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SCANNER_H_
+#define TENSORFLOW_COMPILER_XLA_SCANNER_H_
+
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+
+// Simple class for parsing data. The concepts for the interface are:
+//
+// Match(x): Returns true if x is next in the input and in that case skips
+// past it. Otherwise returns false.
+//
+// Expect(x): As Match(x), but requires x to be next in the input.
+//
+// MatchReadX(x): Returns true if an X is next in the input and in that case
+// skips past it and assigns it to x. Otherwise returns false.
+//
+// ReadX(): As ReadMatchX(), but requires an X to be next in the input and
+// returns it.
+//
+// PeekX(): Returns true if an X is next in the input and does not skip
+// past it either way.
+//
+// All of these, except those that work on individual characters, skip
+// whitespace.
+//
+// If a requirement is not met, the error is available in status(). A Scanner
+// with a failed status() will behave as though the rest of the input is EOF and
+// will not record further errors after that point.
+class Scanner {
+ public:
+ Scanner(tensorflow::StringPiece input);
+
+ bool ok() const;
+ const Status& status() const;
+
+ bool Match(tensorflow::StringPiece match);
+ void Expect(tensorflow::StringPiece expect);
+
+ // Match-reads an identifier. An identifier starts with an alphabetic
+ // character or an underscore followed by any number of characters that are
+ // each alphanumeric or underscore.
+ bool MatchReadIdentifier(string* identifier);
+
+ string ReadIdentifier();
+
+ void ExpectIdentifier(tensorflow::StringPiece expect);
+
+ // Matches the end of the input, also known as End Of File (EOF).
+ bool MatchEof();
+ void ExpectEof();
+
+ // Reads a vector of the format "(1, 4, 5)".
+ std::vector<int64> ReadIntVector();
+
+ // Reads an integer. Can start with a minus but not a plus.
+ int64 ReadInt();
+
+ // Keeps skipping until encountering a non-whitespace character.
+ void SkipWhitespace();
+
+ // *** Below here are character-level methods that do not skip whitespace.
+
+ int ReadChar();
+ int PeekChar() const;
+ bool PeekDigit() const;
+ bool PeekAlnum() const;
+ bool PeekWhitespace() const;
+
+ // Skip past the next count characters.
+ void SkipChars(int64 count);
+
+ private:
+ // Sets a failed status. The input is in effect replaced with EOF after
+ // this. Only the first error is recorded.
+ void SetError(string error_message);
+
+ const tensorflow::StringPiece input_;
+ int64 position_;
+ Status status_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SCANNER_H_
diff --git a/tensorflow/compiler/xla/scanner_test.cc b/tensorflow/compiler/xla/scanner_test.cc
new file mode 100644
index 0000000000..10cd0c6a04
--- /dev/null
+++ b/tensorflow/compiler/xla/scanner_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TODO(b/80179519): Fix open source build for real.
+#if 0
+#include "tensorflow/compiler/xla/scanner.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace xla {
+namespace {
+
+TEST(Scanner, Empty) {
+ Scanner scanner("");
+
+ EXPECT_EQ(scanner.PeekChar(), EOF);
+ EXPECT_TRUE(scanner.MatchEof());
+ EXPECT_TRUE(scanner.Match(""));
+ EXPECT_FALSE(scanner.Match("1"));
+ EXPECT_TRUE(scanner.ok());
+}
+
+TEST(Scanner, Prefix) {
+ Scanner scanner("1234 5");
+ EXPECT_FALSE(scanner.MatchEof());
+ EXPECT_TRUE(scanner.Match("12"));
+ EXPECT_TRUE(scanner.Match("34 "));
+ EXPECT_FALSE(scanner.MatchEof());
+ EXPECT_FALSE(scanner.Match("5 "));
+ EXPECT_TRUE(scanner.Match("5"));
+ EXPECT_TRUE(scanner.MatchEof());
+}
+
+TEST(Scanner, Whitespace) {
+ Scanner scanner(" \t\n\r 1\t2\n\n");
+
+ EXPECT_FALSE(scanner.Match(" "));
+ EXPECT_TRUE(scanner.Match("1"));
+ EXPECT_TRUE(scanner.Match("2"));
+ EXPECT_TRUE(scanner.MatchEof());
+ EXPECT_TRUE(scanner.ok());
+}
+
+TEST(Scanner, Fail) {
+ Scanner scanner("153 4q");
+
+ scanner.Expect("5");
+ EXPECT_FALSE(scanner.ok());
+ EXPECT_FALSE(scanner.status().ok());
+
+ EXPECT_TRUE(scanner.MatchEof());
+}
+
+TEST(Scanner, Identifier) {
+ Scanner scanner("1 q1 _1_ _1a= qqb");
+
+ string identifier = "foo";
+ EXPECT_FALSE(scanner.MatchReadIdentifier(&identifier));
+ EXPECT_EQ(identifier, "foo");
+ scanner.Match("1");
+
+ EXPECT_TRUE(scanner.MatchReadIdentifier(&identifier));
+ EXPECT_EQ(identifier, "q1");
+
+ scanner.ExpectIdentifier("_1_");
+ EXPECT_TRUE(scanner.ok());
+
+ scanner.ExpectIdentifier("_1a");
+ EXPECT_TRUE(scanner.ok());
+
+ // The = after _1a is not included in the identifier.
+ scanner.Expect("=");
+
+ // The expected identifier matches a prefix but is not the full identifier in
+ // the input.
+ EXPECT_TRUE(scanner.ok());
+ scanner.ExpectIdentifier("qq");
+ EXPECT_FALSE(scanner.ok());
+}
+
+TEST(Scanner, Int) {
+ Scanner scanner("1_2 3% -1 124345 -363 0 -0");
+ EXPECT_EQ(1, scanner.ReadInt());
+ EXPECT_TRUE(scanner.Match("_"));
+ EXPECT_EQ(2, scanner.ReadInt());
+ EXPECT_EQ(3, scanner.ReadInt());
+ EXPECT_TRUE(scanner.Match("%"));
+ EXPECT_EQ(-1, scanner.ReadInt());
+ EXPECT_EQ(124345, scanner.ReadInt());
+ EXPECT_EQ(-363, scanner.ReadInt());
+ EXPECT_EQ(0, scanner.ReadInt());
+ EXPECT_EQ(0, scanner.ReadInt());
+ EXPECT_TRUE(scanner.MatchEof());
+}
+
+TEST(Scanner, IntVector) {
+ Scanner scanner("()(0) (-1,2) ( 3 , 4 )");
+ EXPECT_THAT(scanner.ReadIntVector(), testing::IsEmpty());
+ EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(0));
+ EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(-1, 2));
+ EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(3, 4));
+ EXPECT_TRUE(scanner.MatchEof());
+ EXPECT_TRUE(scanner.ok());
+}
+
+} // namespace
+} // namespace xla
+#endif
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index d1722644c7..5472f9a637 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -376,6 +376,7 @@ cc_library(
":hlo",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -387,7 +388,6 @@ tf_cc_test(
":hlo_matchers",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -431,6 +431,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2861,6 +2862,7 @@ tf_cc_test(
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -2925,6 +2927,7 @@ cc_library(
hdrs = ["indexed_array_analysis.h"],
deps = [
":hlo",
+ ":hlo_evaluator",
":hlo_pass",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index f732ed8f39..c65c91e8e0 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -157,6 +157,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleSubtract(HloInstruction* sub) override;
+ Status HandleMap(HloInstruction* map) override;
+
Status HandleMaximum(HloInstruction* maximum) override;
Status HandleMinimum(HloInstruction* minimum) override;
@@ -2188,6 +2190,39 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
return true;
}
+Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
+ auto* map_computation = map->to_apply();
+ auto* map_root = map_computation->root_instruction();
+ if (map_root->opcode() == HloOpcode::kParameter) {
+ ReplaceInstructionIfSameShape(
+ map, map->mutable_operand(map_root->parameter_number()));
+ return Status::OK();
+ }
+ if (map_root->opcode() == HloOpcode::kConstant) {
+ if (!ShapeUtil::IsScalar(map_root->shape())) {
+ return Status::OK();
+ }
+ auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
+ if (ShapeUtil::IsScalar(map->shape())) {
+ return ReplaceWithNewInstruction(map, std::move(clone));
+ }
+ return ReplaceWithNewInstruction(
+ map,
+ HloInstruction::CreateBroadcast(
+ map->shape(), computation_->AddInstruction(std::move(clone)), {}));
+ }
+ std::vector<HloInstruction*> new_operands;
+ for (auto* root_operand : map_root->operands()) {
+ if (root_operand->opcode() != HloOpcode::kParameter) {
+ return Status::OK();
+ }
+ new_operands.push_back(
+ map->mutable_operand(root_operand->parameter_number()));
+ }
+ auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
+ return ReplaceWithNewInstruction(map, std::move(clone));
+}
+
Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
// Match the following tree:
// min_operand operand
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 4e082877c7..d5f0afe960 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -143,6 +143,39 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
EXPECT_EQ(root, param0);
}
+TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
+ HloComputation::Builder builder(TestName());
+ // Create add computation.
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module().AddEmbeddedComputation(builder.Build());
+ }
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1});
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kMap);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Add(param0, zero));
+}
+
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
index 08d0152e3c..1b8b2d2045 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
@@ -182,15 +182,26 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum(
HloInstruction* crs) {
- if (!ShapeUtil::IsTuple(crs->shape()) ||
- !bfloat16_support_->SupportsMixedPrecisions(*crs)) {
- return DefaultAction(crs);
- }
-
// First use DefaultAction() to handle the operands. It can't handle
// tuple-shaped output.
TF_RETURN_IF_ERROR(DefaultAction(crs));
+ if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) {
+ return Status::OK();
+ }
+
+ // If the output is not a tuple, we don't need special handling.
+ if (!ShapeUtil::IsTuple(crs->shape())) {
+ return Status::OK();
+ }
+
+ // If crs is the root instruction, we should keep its original output type.
+ // The root instruction implicitly has a use from being the result of the
+ // computation, and the code below does not take this use into account.
+ if (crs == computation_->root_instruction()) {
+ return Status::OK();
+ }
+
// Then do per-tuple-element handling on the output.
std::vector<std::vector<HloInstruction*>> per_tuple_element_gtes(
crs->operand_count());
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
index 8b01a6c4b5..31f84e88f8 100644
--- a/tensorflow/compiler/xla/service/compiler.cc
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -28,6 +28,12 @@ namespace xla {
/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
tensorflow::LINKER_INITIALIZED);
+std::vector<string> Compiler::ComputeBackendConfigs(
+ const HloInstruction& hlo, se::StreamExecutor* executor) const {
+ CHECK(executor != nullptr);
+ return {};
+}
+
/* static */ std::map<se::Platform::Id, Compiler::CompilerFactory>*
Compiler::GetPlatformCompilerFactories() {
static auto* r = new std::map<se::Platform::Id, CompilerFactory>;
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index a4b59d1ba9..c39db58b78 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -24,9 +24,11 @@ limitations under the License.
#include <map>
#include <memory>
#include <string>
+#include <vector>
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
@@ -153,6 +155,15 @@ class Compiler {
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
DeviceMemoryAllocator* device_allocator) = 0;
+ // Returns the backend configurations that the backend will consider for the
+ // given HLO. Returns no configurations if the backend does not support
+ // configurations for the given HLO.
+ //
+ // The stream executor is passed in to provide information about the hardware
+ // that the backend configurations would be targeting.
+ virtual std::vector<string> ComputeBackendConfigs(
+ const HloInstruction& hlo, se::StreamExecutor* executor) const;
+
// Compiles the HLO module for ahead-of-time execution. This is intended for
// use in static compilation.
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index af69fc3da9..d77076546f 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -42,17 +42,17 @@ using llvm_ir::SetToFirstInsertPoint;
namespace cpu {
namespace {
-// Loads a tile of values from a 2D tensor.
-class TileLoader {
+// Provides tiled access to an in-memory rank 2 array.
+class MemoryTile {
public:
- // Constructs a TileLoader that will load a tile consisting of
+ // Constructs a MemoryTile that can operate on tiles consisting of
// `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
// `major_dim_offset` in the major dimension. The tile size along the minor
// dimension is the vector size, and that is implicitly determined by `vsl`.
- TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
+ MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
llvm::Value* matrix, int64 matrix_size_along_minor_dim,
llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
- : vsl_(vsl) {
+ : vsl_(vsl), ir_builder_(ir_builder) {
pointers_.reserve(tile_size_along_major_dim);
for (int64 i = 0; i < tile_size_along_major_dim; i++) {
llvm::Value* total_offset = ir_builder->CreateMul(
@@ -62,9 +62,10 @@ class TileLoader {
}
}
- // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at
- // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the
- // minor dimension.
+ // Load a tile consisting of `tile_size_along_major_dim` vectors from position
+ // {major: `major_dim_offset`, minor: `minor_dim_offset`}.
+ //
+ // Note: `major_dim_offset` is a parameter to the constructor.
std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
std::vector<llvm::Value*> result;
result.reserve(pointers_.size());
@@ -74,11 +75,104 @@ class TileLoader {
return result;
}
+ // Stores `tile` to position {major: `major_dim_offset`, minor:
+ // `minor_dim_offset`}.
+ //
+ // Note: `major_dim_offset` is a parameter to the constructor.
+ void StoreTile(tensorflow::gtl::ArraySlice<llvm::Value*> tile,
+ llvm::Value* minor_dim_offset) const {
+ CHECK_EQ(tile.size(), pointers_.size());
+ for (int64 i = 0; i < pointers_.size(); i++) {
+ vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset);
+ }
+ }
+
+ // Loads a tile of size [`tile_size_along_major_dim`,
+ // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`,
+ // minor: `minor_dim_offset`} and then broadcasts each element into a vector
+ // of size vsl_.vector_size(). The (i,j)'th element of the return value is
+ // the (i,j)'th element in the tile broadcasted into an LLVM vector.
+ //
+ // Note: `major_dim_offset` is a parameter to the constructor.
+ std::vector<std::vector<llvm::Value*>> LoadBroadcastTile(
+ llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const {
+ std::vector<std::vector<llvm::Value*>> result;
+ result.resize(pointers_.size());
+ for (int64 i = 0; i < pointers_.size(); i++) {
+ for (int64 j = 0; j < tile_size_along_middle_dim; j++) {
+ result[i].push_back(vsl_->LoadBroadcast(
+ pointers_[i], ir_builder_->CreateAdd(minor_dim_offset,
+ ir_builder_->getInt64(j))));
+ }
+ }
+ return result;
+ }
+
private:
VectorSupportLibrary* vsl_;
+ llvm::IRBuilder<>* ir_builder_;
std::vector<llvm::Value*> pointers_;
};
+// The base class for the classes representing the GEMV emitter configurations.
+//
+// The IR emitted (modulo the LLVM values representing the input and output
+// buffers) by the row major and column major GEMV emitters should be a function
+// of their configuration. This is important because their configuration is
+// used as a key to cache the generated IR.
+class GemvConfig {
+ public:
+ // Mixin for convenience.
+ template <typename T>
+ struct User {
+ public:
+ PrimitiveType scalar_type() const {
+ return derived().config().scalar_type();
+ }
+ int64 tile_rows() const { return derived().config().tile_rows(); }
+ int64 tile_cols() const { return derived().config().tile_cols(); }
+ int64 m() const { return derived().config().m(); }
+ int64 k() const { return derived().config().k(); }
+ int64 has_addend() const { return derived().config().has_addend(); }
+
+ private:
+ const T& derived() const { return *static_cast<const T*>(this); }
+ };
+
+ PrimitiveType scalar_type() const { return scalar_type_; }
+ int64 tile_rows() const { return tile_rows_; }
+ int64 tile_cols() const { return tile_cols_; }
+ int64 m() const { return m_; }
+ int64 k() const { return k_; }
+ bool has_addend() const { return has_addend_; }
+
+ string GetCacheKey() const {
+ return tensorflow::strings::StrCat(
+ name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_",
+ tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : "");
+ }
+
+ protected:
+ explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows,
+ int64 tile_cols, int64 m, int64 k, bool has_addend)
+ : name_(std::move(name)),
+ scalar_type_(scalar_type),
+ tile_rows_(tile_rows),
+ tile_cols_(tile_cols),
+ m_(m),
+ k_(k),
+ has_addend_(has_addend) {}
+
+ private:
+ string name_;
+ PrimitiveType scalar_type_;
+ int64 tile_rows_;
+ int64 tile_cols_;
+ int64 m_;
+ int64 k_;
+ bool has_addend_;
+};
+
// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the
// layout of the vector does not matter). This implementation uses a tiling
// scheme to improve performance.
@@ -140,38 +234,46 @@ class TileLoader {
// TODO(sanjoy): We should investigate if using gather loads and scatter stores
// can be used here have the same inner loop for both column-major and row-major
// matrix-vector products.
-class ColumnMajorMatrixVectorProductEmitter {
+class ColumnMajorMatrixVectorProductEmitter
+ : public GemvConfig::User<ColumnMajorMatrixVectorProductEmitter> {
public:
- ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
- int64 tile_rows, int64 tile_cols,
- int64 m, int64 k, llvm::Value* lhs,
+ class Config : public GemvConfig {
+ public:
+ explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
+ int64 m, int64 k, bool has_addend)
+ : GemvConfig(/*name=*/"col_major_gemv", scalar_type,
+ /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
+ /*k=*/k, /*has_addend=*/has_addend) {}
+ };
+
+ ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
llvm::Value* rhs, llvm::Value* addend,
llvm::Value* result,
llvm::IRBuilder<>* ir_builder)
- : scalar_type_(scalar_type),
- tile_rows_(tile_rows),
- tile_cols_(tile_cols),
- m_(m),
- k_(k),
+ : config_(config),
lhs_(lhs),
rhs_(rhs),
addend_(addend),
result_(result),
ir_builder_(ir_builder),
ksl_(ir_builder_),
- vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
- CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
+ vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(),
+ ir_builder_, "") {
+ CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows())));
+ CHECK(!has_addend() || addend != nullptr);
}
void Emit();
+ const Config& config() const { return config_; }
+
private:
void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
bool is_first_column);
- TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
- return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
- /*matrix_size_along_minor_dim=*/m_,
+ MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) {
+ return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_,
+ /*matrix_size_along_minor_dim=*/m(),
/*major_dim_offset=*/column_start,
/*tile_size_along_major_dim=*/column_count);
}
@@ -188,18 +290,14 @@ class ColumnMajorMatrixVectorProductEmitter {
return result;
}
- void EmitInnerLoopTiled(TileLoader* lhs_tile_loader,
+ void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile,
const std::vector<llvm::Value*>& rhs_tile,
int64 columns, bool is_first_column);
void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
bool is_first_tiled_column);
- PrimitiveType scalar_type_;
- int64 tile_rows_;
- int64 tile_cols_;
- int64 m_;
- int64 k_;
+ Config config_;
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* addend_;
@@ -211,25 +309,25 @@ class ColumnMajorMatrixVectorProductEmitter {
void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
llvm::Value* column, int64 column_count, bool is_first_column) {
- TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column,
+ MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column,
/*column_count=*/column_count);
std::vector<llvm::Value*> rhs_tile =
LoadRhsTile(column, /*count=*/column_count);
- EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile,
+ EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile,
/*columns=*/column_count, is_first_column);
EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
}
void ColumnMajorMatrixVectorProductEmitter::Emit() {
// See the comment on the class declaration for the algorithm used here.
- int64 column_remainder = k_ % tile_cols_;
- int64 column_limit = k_ - column_remainder;
+ int64 column_remainder = k() % tile_cols();
+ int64 column_limit = k() - column_remainder;
ksl_.For("dot.outer.tiled",
- /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
+ /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
[&](llvm::Value* column, bool is_first_column) {
- EmitOuterLoopBody(column, tile_cols_, is_first_column);
+ EmitOuterLoopBody(column, tile_cols(), is_first_column);
});
if (column_remainder != 0) {
@@ -239,14 +337,14 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() {
}
void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
- TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
+ MemoryTile* lhs_memory_tile, const std::vector<llvm::Value*>& rhs_tile,
int64 columns, bool is_first_column) {
- int64 row_limit = m_ - (m_ % tile_rows_);
+ int64 row_limit = m() - (m() % tile_rows());
ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
- /*step=*/tile_rows_, [&](llvm::Value* row) {
+ /*step=*/tile_rows(), [&](llvm::Value* row) {
std::vector<llvm::Value*> lhs_tile =
- lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
+ lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
llvm::Value* accumulator =
is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
: vsl_.GetZeroVector())
@@ -260,8 +358,8 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
- int64 row_start = m_ - (m_ % tile_rows_);
- if (row_start == m_) {
+ int64 row_start = m() - (m() % tile_rows());
+ if (row_start == m()) {
return;
}
@@ -281,11 +379,11 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
[&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
llvm::Value* total_offset =
- ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
+ ir_builder_->CreateMul(col, ir_builder_->getInt64(m()));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
ksl_.For(
- "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
+ "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
/*step=*/1, [&](llvm::Value* scalar_row) {
llvm::Value* product = vsl_.Mul(
vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
@@ -365,51 +463,55 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
//
// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer
// epilogue loop to deal with the C,D submatrix.
-class RowMajorMatrixVectorProductEmitter {
+class RowMajorMatrixVectorProductEmitter
+ : public GemvConfig::User<RowMajorMatrixVectorProductEmitter> {
public:
- RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows,
- int64 tile_cols, int64 m, int64 k,
- llvm::Value* lhs, llvm::Value* rhs,
- llvm::Value* addend, llvm::Value* result,
+ class Config : public GemvConfig {
+ public:
+ explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols,
+ int64 m, int64 k, bool has_addend)
+ : GemvConfig(/*name=*/"row_major_gemv", scalar_type,
+ /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m,
+ /*k=*/k, /*has_addend=*/has_addend) {}
+ };
+
+ RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs,
+ llvm::Value* rhs, llvm::Value* addend,
+ llvm::Value* result,
llvm::IRBuilder<>* ir_builder)
- : scalar_type_(scalar_type),
- tile_rows_(tile_rows),
- tile_cols_(tile_cols),
- m_(m),
- k_(k),
+ : config_(config),
lhs_(lhs),
rhs_(rhs),
addend_(addend),
result_(result),
ir_builder_(ir_builder),
ksl_(ir_builder_),
- vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
- CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
+ vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") {
+ CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols())));
+ CHECK(!has_addend() || addend != nullptr);
}
void Emit();
+ const Config& config() const { return config_; }
+
private:
- TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
- return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
- /*matrix_size_along_minor_dim=*/k_,
+ MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) {
+ return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_,
+ /*matrix_size_along_minor_dim=*/k(),
/*major_dim_offset=*/row_start,
/*tile_size_along_major_dim=*/row_count);
}
void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
- void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows,
+ void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows,
std::vector<VectorVariable>* vector_accumulators);
void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
std::vector<ScalarVariable>* scalar_accumulators);
- PrimitiveType scalar_type_;
- int64 tile_rows_;
- int64 tile_cols_;
- int64 m_;
- int64 k_;
+ Config config_;
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* addend_;
@@ -421,7 +523,7 @@ class RowMajorMatrixVectorProductEmitter {
void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
int64 row_count) {
- TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row,
+ MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row,
/*row_count=*/row_count);
std::vector<VectorVariable> vector_accumulators;
std::vector<ScalarVariable> scalar_accumulators;
@@ -429,7 +531,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
}
- EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count,
+ EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count,
&vector_accumulators);
EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
&scalar_accumulators);
@@ -466,12 +568,12 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
void RowMajorMatrixVectorProductEmitter::Emit() {
// See the comment on the class declaration for the algorithm used here.
- int64 row_remainder = m_ % tile_rows_;
- int64 row_limit = m_ - row_remainder;
+ int64 row_remainder = m() % tile_rows();
+ int64 row_limit = m() - row_remainder;
ksl_.For("dot.outer.tiled",
- /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
- [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });
+ /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
+ [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
if (row_remainder != 0) {
EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
@@ -479,14 +581,14 @@ void RowMajorMatrixVectorProductEmitter::Emit() {
}
void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
- TileLoader* lhs_tile_loader, int64 rows,
+ MemoryTile* lhs_memory_tile, int64 rows,
std::vector<VectorVariable>* vector_accumulators) {
- int64 column_limit = k_ - (k_ % tile_cols_);
+ int64 column_limit = k() - (k() % tile_cols());
ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
- /*step=*/tile_cols_, [&](llvm::Value* col) {
+ /*step=*/tile_cols(), [&](llvm::Value* col) {
std::vector<llvm::Value*> lhs_tile =
- lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
+ lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
for (int i = 0; i < rows; i++) {
llvm::Value* old_sum = (*vector_accumulators)[i].Get();
@@ -499,18 +601,18 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
llvm::Value* current_tile_row, int64 rows,
std::vector<ScalarVariable>* scalar_accumulators) {
- int64 column_start = k_ - (k_ % tile_cols_);
- if (column_start == k_) {
+ int64 column_start = k() - (k() % tile_cols());
+ if (column_start == k()) {
return;
}
for (int r = 0; r < rows; r++) {
llvm::Value* total_offset = ir_builder_->CreateMul(
ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
- ir_builder_->getInt64(k_));
+ ir_builder_->getInt64(k()));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
- ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
+ ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
/*step=*/1, [&](llvm::Value* scalar_col) {
llvm::Value* product =
vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
@@ -543,16 +645,21 @@ class MatrixMatrixBlockPanelEmitter {
int64 k() const { return k_; }
int64 n() const { return n_; }
+ string ToString() const {
+ return tensorflow::strings::StrCat(m(), "x", k(), "x", n());
+ }
+
private:
const int64 m_;
const int64 k_;
const int64 n_;
};
- // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
- // `lhs` with `rhs` and stores the result in `result`.
+ // Represents the configuration of the GEBP emitter. The LLVM IR emitted by
+ // the emitter, modulo the LLVM values holding the input and output buffers,
+ // must be a function of the instance of `Config` passed to it.
//
- // `m`, `k` and `n` are the matrix multiplication dimensions.
+ // `dims` holds the matrix multiplication dimensions.
//
// `max_vectorization_width` is the maximum vector width (i.e. the width of
// the largest vector register we will use). This can be larger than the
@@ -561,90 +668,143 @@ class MatrixMatrixBlockPanelEmitter {
// `min_vectorization_width` is the smallest vector width the emitter will use
// -- below that it will devolve to using a scalar loop.
//
- // `k_tiling_factor` is the number of elements along the reduction dimensions
- // that we will attempt to process at once.
- explicit MatrixMatrixBlockPanelEmitter(
- llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, Dimensions dims,
- int max_vectorization_width, int min_vectorization_width,
- int k_tiling_factor, const TargetMachineFeatures& target_machine_features,
- llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type)
+ // The innermost reduction loop executes the matrix multiply in tiles of size
+ // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`,
+ // <vectorization width>] in the RHS.
+ class Config {
+ public:
+ explicit Config(PrimitiveType scalar_type, Dimensions dims,
+ int64 max_vectorization_width,
+ int64 min_vectorization_width, int64 tile_size_m,
+ int64 tile_size_k)
+ : scalar_type_(scalar_type),
+ dims_(dims),
+ max_vectorization_width_(max_vectorization_width),
+ min_vectorization_width_(min_vectorization_width),
+ tile_size_m_(tile_size_m),
+ tile_size_k_(tile_size_k) {}
+
+ string GetCacheKey() const {
+ return tensorflow::strings::StrCat(
+ "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(),
+ "_", max_vectorization_width(), "_", min_vectorization_width(), "_",
+ tile_size_m(), "_", tile_size_k());
+ }
+
+ PrimitiveType scalar_type() const { return scalar_type_; }
+ Dimensions dims() const { return dims_; }
+ int64 max_vectorization_width() const { return max_vectorization_width_; }
+ int64 min_vectorization_width() const { return min_vectorization_width_; }
+
+ int64 tile_size_m() const { return tile_size_m_; }
+ int64 tile_size_k() const { return tile_size_k_; }
+
+ private:
+ PrimitiveType scalar_type_;
+ Dimensions dims_;
+ int64 max_vectorization_width_;
+ int64 min_vectorization_width_;
+ int64 tile_size_m_;
+ int64 tile_size_k_;
+ };
+
+ // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies
+ // `lhs` with `rhs` and stores the result in `result`.
+ explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs,
+ llvm::Value* rhs, llvm::Value* result,
+ llvm::IRBuilder<>* ir_builder)
: lhs_(lhs),
rhs_(rhs),
result_(result),
- dims_(dims),
- max_vectorization_width_(max_vectorization_width),
- min_vectorization_width_(min_vectorization_width),
- k_tiling_factor_(k_tiling_factor),
- target_machine_features_(target_machine_features),
+ config_(config),
ir_builder_(ir_builder),
- primitive_type_(primitive_type),
ksl_(ir_builder_) {
- CHECK(max_vectorization_width > 0 &&
- IsPowerOfTwo(static_cast<uint64>(max_vectorization_width)));
- CHECK(min_vectorization_width > 0 &&
- IsPowerOfTwo(static_cast<uint64>(min_vectorization_width)));
- CHECK_GT(k_tiling_factor, 0);
+ CHECK(max_vectorization_width() > 0 &&
+ IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
+ CHECK(min_vectorization_width() > 0 &&
+ IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
+ CHECK_GT(tile_size_k(), 0);
}
void Emit();
private:
- // We can only iterate the `n` dimension for an extent that is divisible by
- // the vectorization width. So we emit an outer loop that first processes the
- // largest extent in `n` that is divisible by max_vectorization_width, then
- // the largest remaining extent that is divisible by max_vectorization_width /
- // 2 etc. This function emits that outermost loop.
- void EmitChunkedLoopOverN();
+ // This emits a loop that loops over the `n` dimension in multiples of
+ // `max_vectorization_width` as much as possible and then emits a remainder
+ // epilogue.
+ void EmitLoopOverN();
// This emits a loop that loops over the `k` dimension in multiples of
- // `k_tiling_factor` as much as possible and then emits a remainder epilogue.
+ // `tile_size_k` as much as possible and then emits a remainder epilogue.
void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start,
llvm::Value* n_end);
- // This emits the inner reduction loop. This inner reduction loop processes
- // all indices in the `m` dimension, [`k_start`, `k_end`) in the k dimension
- // and [`n_start`, `n_end`) in the `n` dimension.
- void EmitInnerLoop(int64 k_tiling_factor, llvm::Value* k_start,
- llvm::Value* k_end, llvm::Value* n_start,
- llvm::Value* n_end, VectorSupportLibrary* vsl);
-
- llvm::Value* getInt64(int64 value) { return ir_builder_->getInt64(value); }
+ // This emits a loop that loops over the `m` dimension in multiples of
+ // `tile_size_m` as much as possible and then emits a remainder epilogue.
+ void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k,
+ llvm::Value* k_start, llvm::Value* k_end,
+ llvm::Value* n_start, llvm::Value* n_end);
+
+ // This emits the inner reduction loop. This inner reduction loop multiplies
+ // a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the
+ // RHS of size [`tile_size_k`, vls->vector_width()] to update a tile of size
+ // [`tile_size_m`, vls->vector_width()] in the result.
+ void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k,
+ llvm::Value* k_start, llvm::Value* k_end,
+ llvm::Value* n_start, llvm::Value* n_end,
+ int64 tile_size_m, llvm::Value* m_start,
+ llvm::Value* m_end);
+
+ llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); }
+
+ Config config() const { return config_; }
+ Dimensions dims() const { return config().dims(); }
+
+ int64 max_vectorization_width() const {
+ return config().max_vectorization_width();
+ }
+ int64 min_vectorization_width() const {
+ return config().min_vectorization_width();
+ }
+ int64 tile_size_m() const { return config().tile_size_m(); }
+ int64 tile_size_k() const { return config().tile_size_k(); }
+ PrimitiveType scalar_type() const { return config().scalar_type(); }
llvm::Value* lhs_;
llvm::Value* rhs_;
llvm::Value* result_;
- Dimensions dims_;
-
- int64 max_vectorization_width_;
- int64 min_vectorization_width_;
- int64 k_tiling_factor_;
+ Config config_;
- const TargetMachineFeatures& target_machine_features_;
llvm::IRBuilder<>* ir_builder_;
- PrimitiveType primitive_type_;
KernelSupportLibrary ksl_;
};
-void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); }
+void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); }
-void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() {
- int64 current_vectorization_width = max_vectorization_width_;
+void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
+ // We can only iterate the `n` dimension for an extent that is divisible by
+ // the vectorization width. So we emit an outer loop that first processes the
+ // largest extent in `n` that is divisible by max_vectorization_width, then
+ // the largest remaining extent that is divisible by max_vectorization_width /
+ // 2 etc.
+
+ int64 current_vectorization_width = max_vectorization_width();
int64 n_start = 0;
- while (n_start != dims_.n() &&
- current_vectorization_width >= min_vectorization_width_) {
- int64 n_end = dims_.n() - (dims_.n() % current_vectorization_width);
+ while (n_start != dims().n() &&
+ current_vectorization_width >= min_vectorization_width()) {
+ int64 n_end = dims().n() - (dims().n() % current_vectorization_width);
if (n_start != n_end) {
- VectorSupportLibrary vsl(primitive_type_, current_vectorization_width,
+ VectorSupportLibrary vsl(scalar_type(), current_vectorization_width,
ir_builder_, "gebp");
- EmitLoopOverK(&vsl, getInt64(n_start), getInt64(n_end));
+ EmitLoopOverK(&vsl, GetInt64(n_start), GetInt64(n_end));
n_start = n_end;
}
current_vectorization_width /= 2;
}
- if (n_start != dims_.n()) {
- VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp");
- ksl_.For("epi.n", n_start, dims_.n(), 1, [&](llvm::Value* n_i) {
+ if (n_start != dims().n()) {
+ VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp");
+ ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next =
ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
EmitLoopOverK(&vsl, n_i, n_i_next);
@@ -656,16 +816,30 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
llvm::Value* n_start,
llvm::Value* n_end) {
int64 k_start = 0;
- int64 k_end = dims_.k() - (dims_.k() % k_tiling_factor_);
+ int64 k_end = dims().k() - (dims().k() % tile_size_k());
if (k_end != k_start) {
- EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start,
- n_end, vsl);
+ EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
+ n_start, n_end);
k_start = k_end;
}
- if (k_start != dims_.k()) {
- EmitInnerLoop(dims_.k() - k_start, getInt64(k_start), getInt64(dims_.k()),
- n_start, n_end, vsl);
+ if (k_start != dims().k()) {
+ EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start),
+ GetInt64(dims().k()), n_start, n_end);
+ }
+}
+
+void MatrixMatrixBlockPanelEmitter::EmitLoopOverM(
+ VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
+ llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
+ const int64 m_end = dims().m() - dims().m() % tile_size_m();
+ EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
+ tile_size_m(), GetInt64(0), GetInt64(m_end));
+
+ if (m_end != dims().m()) {
+ EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
+ dims().m() - m_end, GetInt64(m_end),
+ GetInt64(dims().m()));
}
}
@@ -673,11 +847,11 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
//
// Let the LHS be:
//
-// +---+---+---+
-// | a | b | c | .
-// +---+---+---+ .
-// | | | | .
-// +---+---+---+
+// +----+----+----+
+// | a0 | b0 | c0 | .
+// +----+----+----+ .
+// | a1 | b1 | c1 | .
+// +----+----+----+
// .. ..
//
// and the RHS be:
@@ -691,72 +865,77 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
// +----+----+----+----+ .
// ...... ......
//
-// and let k_tiling_factor be 3 and the vector width (implicitly denoted by
-// `vsl`) be 4.
+// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted
+// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4]
+// matrix that we can increment the result matrix by.
//
-// Then we
+// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank
+// 3 array, L, of dimension [2,3,4]:
//
-// 1. broadcast the first row in LHS to 3 vectors of width 4
-// 2. elementwise multiply the RHS rows with these broadcasted vectors
-// 3. elementwise add them:
+// L[0,_,_] * L[1,_,_]
+// *
+// +----+----+----+----+ * +----+----+----+----+
+// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 |
+// +----+----+----+----+ * +----+----+----+----+
+// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 |
+// +----+----+----+----+ * +----+----+----+----+
+// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 |
+// +----+----+----+----+ * +----+----+----+----+
//
-// +---+---+---+---+ +----+----+----+----+
-// | a | a | a | a | * | p0 | p1 | p2 | p3 | +
-// +---+---+---+---+ +----+----+----+----+
//
-// +---+---+---+---+ +----+----+----+----+
-// | b | b | b | b | * | q0 | q1 | q2 | q3 | +
-// +---+---+---+---+ +----+----+----+----+
+// Then we FMA L[0,_,_] with the RHS to get the first row of the result and
+// L[1,_,_] with the RHS to get the second row of the result. For example,
+// L[0,_,_] is computed as:
//
-// +---+---+---+---+ +----+----+----+----+
-// | c | c | c | c | * | r0 | r1 | r2 | r3 |
-// +---+---+---+---+ +----+----+----+----+
+// +----+----+----+----+ +----+----+----+----+
+// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | +
+// +----+----+----+----+ +----+----+----+----+
//
-// to get:
+// +----+----+----+----+ +----+----+----+----+
+// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | +
+// +----+----+----+----+ +----+----+----+----+
//
-// +----------------+----------------+----------------+----------------+
-// | a*p0+b*q0+c*r0 | a*p1+b*q1+c*r1 | a*p2+b*q2+c*r2 | a*p3+b*q3+c*r3 |
-// +----------------+----------------+----------------+----------------+
+// +----+----+----+----+ +----+----+----+----+
+// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 |
+// +----+----+----+----+ +----+----+----+----+
//
-// which we increment into the appropriate region in the result.
-void MatrixMatrixBlockPanelEmitter::EmitInnerLoop(
- int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end,
- llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) {
- ksl_.For("dot.m", 0, dims_.m(), 1, [&](llvm::Value* m_i) {
- // This outer loop iterates over all of the M dimension
- llvm::Value* result_row_begin = vsl->ComputeOffsetPointer(
- result_, /*offset_elements=*/m_i, /*scale=*/dims_.n());
- llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer(
- lhs_, /*offset_elements=*/m_i, /*scale=*/dims_.k());
-
- ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) {
- // broadcasted_a is the broadcasted set of vectors denoted as <a,a,a,a>,
- // <b,b,b,b> etc. in the diagram.
- std::vector<llvm::Value*> broadcasted_a;
- broadcasted_a.reserve(k_tiling_factor);
- for (int i = 0; i < k_tiling_factor; i++) {
- broadcasted_a.push_back(vsl->LoadBroadcast(
- lhs_row_begin, ir_builder_->CreateAdd(getInt64(i), k_i)));
- }
-
- // rhs_loader will be used to load the tile off of the RHS, denoted as
- // <<p0,p1,p2,p3>,<q0,q1,q2,q3> ...> in the diagram.
- TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims_.n(), k_i,
- k_tiling_factor);
+// to get:
+//
+// +-------------------+-------------------+-------------------+---------
+// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
+// +-------------------+-------------------+-------------------+---------
+void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop(
+ VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
+ llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
+ int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
+ ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
+ MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_,
+ /*matrix_size_along_minor_dim=*/dims().n(),
+ /*major_dim_offset=*/m_i,
+ /*tile_size_along_major_dim=*/tile_size_m);
+ MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_,
+ /*matrix_size_along_minor_dim=*/dims().k(),
+ /*major_dim_offset=*/m_i,
+ /*tile_size_along_major_dim=*/tile_size_m);
+
+ ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
+ MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i,
+ tile_size_k);
+ std::vector<std::vector<llvm::Value*>> lhs_tile =
+ lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
ksl_.For(
"dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
- // This loop iterates over the N dimension. It loads the tile from
- // RHS, does the FMA resulting in the
- // <a*p0+b*q0+c*r0,a*p1+b*q1+c*r1,...> in the diagram and increments
- // the result.
- std::vector<llvm::Value*> tile = rhs_loader.LoadTile(n_i);
- llvm::Value* result_accumulator =
- vsl->LoadVector(result_row_begin, n_i);
- for (int i = 0; i < tile.size(); i++) {
- result_accumulator =
- vsl->MulAdd(tile[i], broadcasted_a[i], result_accumulator);
+ std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
+ std::vector<llvm::Value*> result_tile =
+ result_memory_tile.LoadTile(n_i);
+ for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
+ for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
+ result_tile[r_m_i] =
+ vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
+ result_tile[r_m_i]);
+ }
}
- vsl->StoreVector(result_accumulator, result_row_begin, n_i);
+ result_memory_tile.StoreTile(result_tile, n_i);
});
});
});
@@ -827,8 +1006,6 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
return false;
}
- VLOG(2) << "Emitting GEBP kernel in LLVM IR";
-
llvm::Value* lhs = lhs_array_.GetBasePointer();
llvm::Value* rhs = rhs_array_.GetBasePointer();
llvm::Value* target = target_array_.GetBasePointer();
@@ -846,14 +1023,36 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
target, ir_builder_->getInt8(0), size_bytes,
target_machine_features_.minimum_alignment_for_allocation(size_bytes));
- MatrixMatrixBlockPanelEmitter::Dimensions gebp_dims(/*m=*/m, /*k=*/k,
- /*n=*/n);
- MatrixMatrixBlockPanelEmitter gebp_emitter(
- /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, gebp_dims,
- /*max_vectorization_width=*/8, /*min_vectorization_width=*/4,
- /*k_tiling_factor=*/8, target_machine_features_, ir_builder_,
- primitive_type);
- gebp_emitter.Emit();
+ int64 max_vector_width =
+ target_machine_features_.vector_register_num_elements(
+ *ir_builder_->GetInsertBlock()->getParent(), primitive_type);
+
+ MatrixMatrixBlockPanelEmitter::Config config(
+ /*scalar_type=*/primitive_type,
+ MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
+ /*max_vectorization_width=*/max_vector_width,
+ /*min_vectorization_width=*/std::min<int64>(4, max_vector_width),
+ /*tile_size_m=*/3, /*tile_size_k=*/5);
+
+ VLOG(2) << "Emitting GEBP kernel in LLVM IR with config "
+ << config.GetCacheKey();
+
+ const bool enable_fast_math =
+ hlo_module_config_.debug_options().xla_enable_fast_math();
+ const bool optimize_for_size =
+ options::OptimizeForSizeRequested(hlo_module_config_);
+
+ KernelSupportLibrary::EmitAndCallOutlinedKernel(
+ /*enable_fast_math=*/enable_fast_math,
+ /*optimize_for_size=*/optimize_for_size, ir_builder_,
+ config.GetCacheKey(), lhs, rhs, target,
+ [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) {
+ MatrixMatrixBlockPanelEmitter gebp_emitter(
+ config, /*lhs=*/lhs, /*rhs=*/rhs,
+ /*result=*/target, ir_builder_);
+ gebp_emitter.Emit();
+ });
+
return true;
}
@@ -942,47 +1141,39 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
if (is_column_major_matrix_vector) {
VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
<< " and k = " << k;
- int64 tile_rows = vector_register_element_size;
- int64 tile_cols = tiling_factor;
-
- string kernel_name = tensorflow::strings::StrCat(
- "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
- "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");
+ ColumnMajorMatrixVectorProductEmitter::Config config(
+ /*scalar_type=*/primitive_type,
+ /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor,
+ /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr);
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
- /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
- lhs_op, rhs_op,
+ /*optimize_for_size=*/optimize_for_size, ir_builder_,
+ config.GetCacheKey(), lhs_op, rhs_op,
addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
- [this, tile_rows, tile_cols, m, k, primitive_type](
- llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
- llvm::Value* result_op) {
+ [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op,
+ llvm::Value* addend_op, llvm::Value* result_op) {
ColumnMajorMatrixVectorProductEmitter emitter(
- primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
- addend_op, result_op, ir_builder_);
+ config, lhs_op, rhs_op, addend_op, result_op, ir_builder_);
emitter.Emit();
});
} else {
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
<< " and k = " << k;
- int64 tile_rows = tiling_factor;
- int64 tile_cols = vector_register_element_size;
-
- string kernel_name = tensorflow::strings::StrCat(
- "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
- "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : "");
+ RowMajorMatrixVectorProductEmitter::Config config(
+ /*scalar_type=*/primitive_type,
+ /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size,
+ /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr);
KernelSupportLibrary::EmitAndCallOutlinedKernel(
/*enable_fast_math=*/enable_fast_math,
- /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name,
- lhs_op, rhs_op,
+ /*optimize_for_size=*/optimize_for_size, ir_builder_,
+ config.GetCacheKey(), lhs_op, rhs_op,
addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op,
- [this, tile_rows, tile_cols, m, k, primitive_type](
- llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op,
- llvm::Value* result_op) {
+ [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op,
+ llvm::Value* addend_op, llvm::Value* result_op) {
RowMajorMatrixVectorProductEmitter emitter(
- primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
- addend_op, result_op, ir_builder_);
+ config, lhs_op, rhs_op, addend_op, result_op, ir_builder_);
emitter.Emit();
});
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 4012f87f2b..2794930248 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -338,6 +338,7 @@ cc_library(
srcs = ["cudnn_convolution_runner.cc"],
hdrs = ["cudnn_convolution_runner.h"],
deps = [
+ ":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@@ -401,6 +402,9 @@ tf_cc_test(
srcs = ["instruction_fusion_test.cc"],
deps = [
":instruction_fusion",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -542,6 +546,8 @@ cc_library(
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
+ "//tensorflow/compiler/xla/service:while_loop_constant_sinking",
+ "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
"//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter",
@@ -587,14 +593,18 @@ cc_library(
srcs = ["gpu_layout_assignment.cc"],
hdrs = ["gpu_layout_assignment.h"],
deps = [
+ ":gpu_options",
":ir_emission_utils",
+ ":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
],
)
@@ -691,6 +701,27 @@ cc_library(
],
)
+cc_library(
+ name = "gpu_options",
+ srcs = ["gpu_options.cc"],
+ hdrs = ["gpu_options.h"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "stream_executor_util",
+ srcs = ["stream_executor_util.cc"],
+ hdrs = ["stream_executor_util.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
tf_cc_test(
name = "gpu_hlo_support_checker_test",
srcs = ["gpu_hlo_support_checker_test.cc"],
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 10b4c3de89..0645fbb3ad 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -113,8 +115,17 @@ Status RunCudnnConvolution(
// cuDNN's convolution APIs support the BDYX layout for activations/output and
// the OIYX layout for weights.
+ DataLayout input_dl;
+ FilterLayout filter_dl;
+ DataLayout output_dl;
+
+ TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
+ XlaConvLayoutsToStreamExecutorLayouts(
+ dnums, input_shape.layout(), filter_shape.layout(),
+ output_shape.layout()));
+
BatchDescriptor input_descriptor(effective_num_dimensions);
- input_descriptor.set_layout(DataLayout::kBatchDepthYX)
+ input_descriptor.set_layout(input_dl)
.set_feature_map_count(
input_shape.dimensions(dnums.input_feature_dimension()))
.set_count(input_shape.dimensions(dnums.input_batch_dimension()));
@@ -126,7 +137,7 @@ Status RunCudnnConvolution(
}
FilterDescriptor filter_descriptor(effective_num_dimensions);
- filter_descriptor.set_layout(FilterLayout::kOutputInputYX)
+ filter_descriptor.set_layout(filter_dl)
.set_input_feature_map_count(
filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
.set_output_feature_map_count(
@@ -149,7 +160,7 @@ Status RunCudnnConvolution(
}
BatchDescriptor output_descriptor(effective_num_dimensions);
- output_descriptor.set_layout(DataLayout::kBatchDepthYX)
+ output_descriptor.set_layout(output_dl)
.set_feature_map_count(
output_shape.dimensions(dnums.output_feature_dimension()))
.set_count(output_shape.dimensions(dnums.output_batch_dimension()));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index d50153d8a3..b857219807 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -73,6 +73,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
+#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -157,11 +159,13 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
pass.AddPass<CudnnBatchNormRewriter>();
}
+ // TODO(kramerb): Remove use_fusion once instruction fusion can create
+ // multi-output fusions from the unfused expander output.
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
- /*use_fusion=*/false);
+ /*use_fusion=*/true);
// Rewrite gather ops into smaller ones.
pass.AddPass<GatherExpander>();
@@ -174,6 +178,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; });
pass.AddPass<TupleSimplifier>();
+ pass.AddPass<WhileLoopConstantSinking>();
pass.AddPass<WhileLoopSimplifier>();
pass.AddPass<HloDCE>();
pass.AddPass<ReshapeMover>();
@@ -200,18 +205,28 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<CudnnConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
+
+ {
+ HloPassPipeline pipeline("layout_assignment");
+ pipeline.AddPass<GpuLayoutAssignment>(
+ hlo_module->mutable_device_entry_computation_layout(), stream_exec);
+
+ // The LayoutAssignment pass may leave behind kCopy instructions which are
+ // duplicate or NOPs, so remove them with algebraic simplification and CSE.
+ pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ /*is_layout_sensitive=*/true,
+ /*valid_bitcast_callback=*/[](const Shape&, const Shape&) {
+ return true;
+ });
// Choose the fastest algorithm for each conv.
//
- // In theory doing this here is way too early: It needs to happen after
- // layout assignment, because the layout of the inputs/outputs affects the
- // speed of the conv. But currently we only allow only one input/output
- // layout when calling cudnn, so there's no ambiguity.
- //
- // We pick the algorithm at this early stage so we can generate better HLO.
- // After CudnnConvolutionRewriter, our convolutions are CustomCalls which
- // return a tuple (conv_result, scratch_memory), and the each conv uses 0
- // bytes of scratch:
+ // We pick the algorithm before fusion so we can generate better HLO. After
+ // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a
+ // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of
+ // scratch:
//
// customcall = (f32[...], f32[0])
// return gte(customcall, 0)
@@ -227,35 +242,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// The new tuple and gte instructions then be simplified away, because
// nobody is expected to use the scratch value.
//
- // However, if we were to run CudnnConvolutionAlgorithmPicker after layout
- // assignment, fusion would already have run, and the gte(customcall, 0)
- // would probably already be into a fusion node. We can't simplify across
- // HloComputation boundaries, so in this case we wouldn't be able to
- // simplify away the new_tuple bits.
- //
- // We'll need to revisit this if we ever allow multiple layouts for the
- // inputs/outputs of a cudnn convolution.
+ // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion
+ // the gte(customcall, 0) would probably already be into a fusion node. We
+ // can't simplify across HloComputation boundaries, so in this case we
+ // wouldn't be able to simplify away the new_tuple bits.
pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
device_allocator);
// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();
- pipeline.AddPass<HloDCE>();
- TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
- }
-
- {
- HloPassPipeline pipeline("layout_assignment");
- pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_device_entry_computation_layout());
-
- // The LayoutAssignment pass may leave behind kCopy instructions which are
- // duplicate or NOPs, so remove them with algebraic simplification and CSE.
- pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
- /*is_layout_sensitive=*/true,
- /*valid_bitcast_callback=*/[](const Shape&, const Shape&) {
- return true;
- });
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
@@ -282,6 +277,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
}
}
+
+ {
+ // Do an aggressive LICM pass over while loops. In particular, this hoists
+ // constants that were sunk by WhileLoopConstantSinking. Leaving them in
+ // the while loop may result in unnecessary copies.
+ HloPassPipeline pipeline("while-loop-licm");
+ pipeline.AddPass<WhileLoopInvariantCodeMotion>(true);
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 89f1e62588..178457721a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -18,31 +18,72 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
namespace gpu {
-// cuDNN convolutions are called with specific layouts on the input, output,
-// and filter:
-//
-// input: DataLayout::kBatchDepthYX
-// output: DataLayout::kBatchDepthYX
-// filter: FilterLayout::kOutputInputYX
-//
-// The order dimensions in the constant name is major-to-minor (eg, the
-// most-major dimension of the input is batch, most-minor is X). The
-// specific dimension numbers these named dimensions correspond to is
-// determined by the ConvolutionDimensionNumbers argument. Y is spatial
-// dimension 0, and X is spatial dimension 1.
-//
-// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls.
-static Status AddBackendConstraintsToDnnConvCustomCall(
+using stream_executor::dnn::DataLayout;
+using stream_executor::dnn::FilterLayout;
+
+static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) {
+ int major, minor;
+ CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major,
+ &minor));
+ return major >= 7;
+}
+
+// Returns (input, filter, output) layouts.
+static std::tuple<DataLayout, FilterLayout, DataLayout>
+HeuristicLayoutAssignment(const HloInstruction* instr,
+ stream_executor::StreamExecutor* stream_executor) {
+ // DataLayout and FilterLayout uses weird enum names. Translations:
+ // N <=> Batch or Output
+ // C <=> Depth or Input
+ // H <=> Y
+ // W <=> X
+ //
+ // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW.
+
+ // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x
+ // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version
+ // changes, as well as the hardware updates.
+ if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 &&
+ IsVoltaOrLater(*stream_executor))) {
+ return std::make_tuple(DataLayout::kBatchDepthYX,
+ FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX);
+ }
+ VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
+ // For BackwardInput that has stride, full NHWC layouts run significantly
+ // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC).
+ //
+ // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW,
+ // NHWC).
+ if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
+ window_util::HasStride(instr->window())) {
+ return std::make_tuple(DataLayout::kBatchYXDepth,
+ FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX);
+ }
+ return std::make_tuple(DataLayout::kBatchYXDepth,
+ FilterLayout::kOutputYXInput,
+ DataLayout::kBatchYXDepth);
+}
+
+// Adds layout constraints on the cudnn custom-call instruction. The layout
+// constraints are represented in terms of minor_to_major fields of both
+// operands and the output shape. Depending on the underlying algorithm, one of
+// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
+Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
HloInstruction* instr, LayoutConstraints* constraints) {
CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString();
Shape input_shape;
@@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall(
<< instr->custom_call_target();
}
- // Construct minor-to-major dimension orders for operands and result.
- // cuDNN's convolution APIs support the BDYX layout for activations/output
- // and the OIYX layout for weights.
- // TODO(b/29399649): Be more flexible about handling layouts of cuDNN
- // calls after we switch to cuDNN v5.
- const ConvolutionDimensionNumbers& dimension_numbers =
- instr->convolution_dimension_numbers();
- std::vector<int64> input_layout;
- for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0;
- --i) {
- input_layout.push_back(dimension_numbers.input_spatial_dimensions(i));
- }
- input_layout.push_back(dimension_numbers.input_feature_dimension());
- input_layout.push_back(dimension_numbers.input_batch_dimension());
- *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout);
-
- std::vector<int64> filter_layout;
- for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0;
- --i) {
- filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i));
- }
- filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension());
- filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension());
- *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout);
-
- std::vector<int64> output_layout;
- for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0;
- --i) {
- output_layout.push_back(dimension_numbers.output_spatial_dimensions(i));
+ {
+ DataLayout input;
+ FilterLayout filter;
+ DataLayout output;
+ if (ConvUseLayoutHeuristic(instr->GetModule()->config())) {
+ std::tie(input, filter, output) =
+ HeuristicLayoutAssignment(instr, stream_executor_);
+ } else {
+ input = DataLayout::kBatchDepthYX;
+ filter = FilterLayout::kOutputInputYX;
+ output = DataLayout::kBatchDepthYX;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(),
+ *output_shape.mutable_layout()),
+ StreamExecutorConvLayoutsToXlaLayouts(
+ instr->convolution_dimension_numbers(), input, filter, output));
}
- output_layout.push_back(dimension_numbers.output_feature_dimension());
- output_layout.push_back(dimension_numbers.output_batch_dimension());
- *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout);
// The custom call returns a tuple of (actual_result, scratch_buffer);
// call_result_buf is the logical buffer for actual_result, the thing that
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 86a3a7111f..ce24af1cf8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
@@ -27,8 +28,10 @@ namespace gpu {
// layout constraints for operands and results of library calls.
class GpuLayoutAssignment : public LayoutAssignment {
public:
- explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout)
- : LayoutAssignment(entry_computation_layout) {}
+ explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
+ se::StreamExecutor* stream_executor)
+ : LayoutAssignment(entry_computation_layout),
+ stream_executor_(stream_executor) {}
~GpuLayoutAssignment() override {}
protected:
@@ -41,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment {
LayoutConstraints* constraints) override;
bool CustomCallRequiresMajorFirstLayout(
const HloInstruction* instruction) override;
+
+ private:
+ Status AddBackendConstraintsToDnnConvCustomCall(
+ HloInstruction* instr, LayoutConstraints* constraints);
+
+ se::StreamExecutor* stream_executor_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 4c45d2e94a..e48165c142 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
*computation_layout.mutable_result_layout() =
ShapeLayout(result_shape_with_layout);
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
@@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
*computation_layout.mutable_result_layout() = ShapeLayout(result_shape);
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
{result_shape, offset_scale_shape, offset_scale_shape}));
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
{result_shape, scale_shape, scale_shape}));
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first and fourth operands to the batchnorm call should have the
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc
new file mode 100644
index 0000000000..35b4b4e20b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_options.cc
@@ -0,0 +1,28 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace xla {
+namespace gpu {
+
+bool ConvUseLayoutHeuristic(const HloModuleConfig& config) {
+ return !config.debug_options().xla_backend_extra_options().count(
+ "xla_gpu_experimental_conv_disable_layout_heuristic");
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h
new file mode 100644
index 0000000000..498d4a9495
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_options.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+// Helper functions for querying options that are specific to the GPU backend.
+
+namespace xla {
+namespace gpu {
+
+// Returns true if we should use heuristics to assign convolution layouts, as
+// opposed to always assigning NCHW.
+bool ConvUseLayoutHeuristic(const HloModuleConfig& config);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
index 3ddc1c0789..ae310beefa 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
@@ -49,13 +49,25 @@ void InfeedManager::EnqueueBuffers(const std::vector<InfeedBuffer*>& buffers) {
}
InfeedBuffer* InfeedManager::BlockingDequeueBuffer() {
- tensorflow::mutex_lock l(mu_);
- while (enqueued_buffer_.empty()) {
- cv_.wait(l);
+ bool became_empty = false;
+ InfeedBuffer* current_buffer;
+ {
+ tensorflow::mutex_lock l(mu_);
+ while (enqueued_buffer_.empty()) {
+ cv_.wait(l);
+ }
+ current_buffer = enqueued_buffer_.front();
+ enqueued_buffer_.pop_front();
+ dequeued_buffer_.insert(current_buffer);
+ if (enqueued_buffer_.empty()) {
+ became_empty = true;
+ }
+ }
+ if (became_empty) {
+ for (const auto& callback : on_empty_callbacks_) {
+ callback();
+ }
}
- InfeedBuffer* current_buffer = enqueued_buffer_.front();
- enqueued_buffer_.pop_front();
- dequeued_buffer_.insert(current_buffer);
return current_buffer;
}
@@ -88,6 +100,10 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) {
return host_to_device_stream_.get();
}
+void InfeedManager::RegisterOnEmptyCallback(std::function<void()> callback) {
+ on_empty_callbacks_.push_back(std::move(callback));
+}
+
InfeedManager* GetOrCreateInfeedManager() {
static InfeedManager* manager = new InfeedManager;
return manager;
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h
index d5f2216d46..a3fc15cfe3 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h
@@ -21,6 +21,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_
#include <deque>
+#include <vector>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -100,6 +101,10 @@ class InfeedManager {
// returns null.
se::Stream* GetStream(se::StreamExecutor* executor);
+ // Registers a callback that will be called when 'enqueued_buffer_' becomes
+ // empty.
+ void RegisterOnEmptyCallback(std::function<void()> callback);
+
private:
// TODO(b/30467474): Revisit if this mutex becomes a point of
// contention.
@@ -122,6 +127,10 @@ class InfeedManager {
// Executor that the host_to_device_stream belongs to. Not owned.
se::StreamExecutor* host_to_device_executor_;
+
+ // List of callbacks which will be called when 'enqueued_buffer_' becomes
+ // empty.
+ std::vector<std::function<void()>> on_empty_callbacks_;
};
// Singleton creator-or-accessor: Returns the GPU infeed manager.
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 5d5bef6b57..36a1b82a26 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -177,6 +177,26 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
InstructionFusion::ShouldFuse(consumer, operand_index);
}
+bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
+ int64 operand_index) {
+ const HloInstruction* producer = consumer->operand(operand_index);
+ // The IR emitter has limited support for non-loop fusions with multi output
+ // at present.
+ // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion.
+ if (consumer->opcode() == HloOpcode::kFusion &&
+ consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) {
+ return false;
+ }
+ // Multi-output fusion requires instructions with compatible shapes.
+ if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) {
+ return false;
+ }
+ // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for
+ // multi-output fusion. In particular, do not check whether an instruction is
+ // expensive to duplicate, since this doesn't matter here.
+ return GpuInstructionFusion::ShouldFuse(consumer, operand_index);
+}
+
HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
const HloInstruction* producer, const HloInstruction* consumer) {
if (IsReductionToVector(*consumer)) {
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h
index 9fb06b0a24..f629d9ff2c 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h
@@ -31,6 +31,9 @@ class GpuInstructionFusion : public InstructionFusion {
bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override;
+ bool ShouldFuseIntoMultiOutput(HloInstruction* consumer,
+ int64 operand_index) override;
+
HloInstruction::FusionKind ChooseKind(
const HloInstruction* producer, const HloInstruction* consumer) override;
};
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 760e0e90f5..ec60f3a167 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -15,9 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/util.h"
namespace op = xla::testing::opcode_matchers;
@@ -281,7 +284,8 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
- EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion()));
+ EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion()))
+ << module->ToString();
}
// Compute sum(100/p0), where p0 has type s32, twice. Check that the division
@@ -308,7 +312,8 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
.Run(module.get())
- .ValueOrDie());
+ .ValueOrDie())
+ << module->ToString();
}
TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
@@ -337,5 +342,244 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
op::Broadcast(op::Parameter())));
}
+// Counts the HLO ops with a given op code in the specified module.
+static int Count(const HloModule& module, HloOpcode op) {
+ int count = 0;
+ for (const auto* computation : module.computations()) {
+ for (const auto* instruction : computation->instructions()) {
+ if (instruction->opcode() == op) {
+ ++count;
+ }
+ }
+ }
+ return count;
+}
+
+// Returns an HLO instruction from the given computation with the op code.
+static StatusOr<const HloInstruction*> FindHloInstruction(
+ const HloComputation& computation, HloOpcode op) {
+ for (const auto* instruction : computation.instructions()) {
+ if (instruction->opcode() == op) {
+ return instruction;
+ }
+ }
+ return NotFound(
+ "Computation '%s' does not contain an instruction with op code '%s'.",
+ computation.name().c_str(), HloOpcodeString(op).c_str());
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusion) {
+ // sub --> add --> tuple
+ // \---------------/
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[4,3]{1,0} parameter(1)
+ p2 = f32[4,3]{1,0} parameter(2)
+ sub = f32[4,3]{1,0} subtract(p0, p2)
+ add = f32[4,3]{1,0} add(sub, p1)
+ ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add)
+ })")
+ .ValueOrDie();
+
+ ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+
+ // Expect that there is one multi-output fusion and subtract has not been
+ // duplicated.
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
+ EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1);
+ TF_ASSERT_OK_AND_ASSIGN(
+ const HloInstruction* fusion,
+ FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion));
+ EXPECT_THAT(
+ fusion->fused_expression_root(),
+ op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract()));
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) {
+ // tanh --> add --> tuple
+ // \---------------/
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[4,3]{1,0} parameter(1)
+ tanh = f32[4,3]{1,0} tanh(p0)
+ add = f32[4,3]{1,0} add(tanh, p1)
+ ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add)
+ })")
+ .ValueOrDie();
+
+ // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh.
+ ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusion2) {
+ // sub --> add1 --\--------\
+ // \----------> add2 --> tuple
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[4,3]{1,0} parameter(1)
+ p2 = f32[4,3]{1,0} parameter(2)
+ sub = f32[4,3]{1,0} subtract(p0, p2)
+ add1 = f32[4,3]{1,0} add(sub, p1)
+ add2 = f32[4,3]{1,0} add(sub, add1)
+ ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2)
+ })")
+ .ValueOrDie();
+
+ ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+
+ // Expect that there is one multi-output fusion and subtract has not been
+ // duplicated.
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
+ EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1);
+ TF_ASSERT_OK_AND_ASSIGN(
+ const HloInstruction* fusion,
+ FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Add(op::Subtract(), op::Add()),
+ op::Add(op::Subtract(), op::Parameter())));
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusion3) {
+ // sub --> add1 ----\--------\
+ // \ --> add2 --> add3 --> tuple
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[4,3]{1,0} parameter(1)
+ p2 = f32[4,3]{1,0} parameter(2)
+ p3 = f32[4,3]{1,0} parameter(3)
+ sub = f32[4,3]{1,0} subtract(p0, p2)
+ add1 = f32[4,3]{1,0} add(sub, p1)
+ add2 = f32[4,3]{1,0} add(p2, sub)
+ add3 = f32[4,3]{1,0} add(add1, add2)
+ ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2)
+ })")
+ .ValueOrDie();
+
+ ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+
+ // Expect that there is one multi-output fusion and subtract has not been
+ // duplicated.
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1);
+ EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1);
+ TF_ASSERT_OK_AND_ASSIGN(
+ const HloInstruction* fusion,
+ FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion));
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Add(op::Add(), op::Add()),
+ op::Add(op::Parameter(), op::Subtract())));
+}
+
+TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) {
+ // sub --> mul ---\
+ // \--> call --> add --> tuple
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ c = f32[] constant(42)
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[4,3]{1,0} parameter(1)
+ sub = f32[4,3]{1,0} subtract(p0, p1)
+ mul = f32[4,3]{1,0} multiply(sub, c)
+ call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo"
+ add = f32[4,3]{1,0} add(mul, call)
+ ROOT tuple = (f32[4,3]{1,0}) tuple(add)
+ })")
+ .ValueOrDie();
+
+ ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+ // Visit instructions in post order to detect cycles.
+ // TODO(tjoerg): Add cycle detection to the HloVerifier.
+ class DummyVisitor : public DfsHloVisitorWithDefault {
+ public:
+ DummyVisitor() {}
+ Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
+ return Status::OK();
+ }
+ } visitor;
+ for (const HloComputation* computation : module->MakeComputationPostOrder()) {
+ // Accept will return a FailedPrecondition when a cycle is detected.
+ EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok());
+ }
+}
+
+TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) {
+ // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3])
+ // \-------------------------/
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY OutputFusion {
+ p0 = f32[2,3]{1,0} parameter(0)
+ p1 = f32[4,3]{1,0} parameter(1)
+ p2 = f32[2,3]{1,0} parameter(2)
+ sub = f32[2,3]{1,0} subtract(p0, p2)
+ add = f32[4,3]{1,0} add(sub, p1)
+ ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add)
+ })")
+ .ValueOrDie();
+
+ // Multi-output fusion requires shapes to be compatible. Since `sub` and `add`
+ // have incompatible shapes, expect that no multi-output fusion happens.
+ ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+}
+
+TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+
+ add_computation {
+ add_lhs = f32[] parameter(0)
+ add_rhs = f32[] parameter(1)
+ ROOT add_root = f32[] add(add_lhs, add_rhs)
+ }
+
+ fused_computation {
+ p1 = f32[10] parameter(0)
+ zero = f32[] constant(0)
+ ROOT f2_root = f32[] reduce(p1, zero), dimensions={0},
+ to_apply=add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[10] parameter(0)
+ mul = f32[10] multiply(p0, p0)
+ fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation
+ ROOT tuple = (f32[10], f32[]) tuple(fusion, mul)
+ })")
+ .ValueOrDie();
+
+ // Multi-output fusion is not supported for non-loop fusions at present. Since
+ // `fused_computation` is a input fusion, expect no multi-output fusion to
+ // happen.
+ ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie())
+ << module->ToString();
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index b0accc08d4..e55dfc6dae 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -120,10 +120,11 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::Value* GetBasePointer(const HloInstruction& inst) const {
return bindings_.GetBasePointer(inst);
}
- // A convenient helper for calling BufferAssignment::GetUniqueTopLevelSlice.
- BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const {
+ // A convenient helper for calling BufferAssignment::GetUniqueSlice.
+ BufferAllocation::Slice GetAllocationSlice(
+ const HloInstruction& hlo, const ShapeIndex& index = {}) const {
return ir_emitter_context_->buffer_assignment()
- .GetUniqueTopLevelSlice(&hlo)
+ .GetUniqueSlice(&hlo, index)
.ConsumeValueOrDie();
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 55d4c1d13d..ae4e305b80 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -79,6 +79,7 @@ namespace {
using llvm_ir::IrName;
using tensorflow::gtl::ArraySlice;
+using tensorflow::gtl::InlinedVector;
using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;
using tensorflow::strings::StrCat;
@@ -499,12 +500,24 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// initializes the output array to the initial value of the reduce.
if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
switch (root->opcode()) {
+ case HloOpcode::kTuple:
case HloOpcode::kReduce: {
VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
- BuildInitializerThunk(fusion));
std::vector<std::unique_ptr<Thunk>> thunks;
- thunks.push_back(std::move(initializer_thunk));
+ ArraySlice<HloInstruction*> reduces =
+ root->opcode() == HloOpcode::kTuple
+ ? root->operands()
+ : ArraySlice<HloInstruction*>(&root, 1);
+
+ // For multi-output fusion emit an initializer for each tuple element.
+ // Otherwise it's sufficient to just initialize the single output.
+ for (int i = 0, e = reduces.size(); i != e; ++i) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Thunk> initializer_thunk,
+ BuildInitializerThunk(
+ fusion, reduces[i] == root ? ShapeIndex() : ShapeIndex({i})));
+ thunks.push_back(std::move(initializer_thunk));
+ }
thunks.push_back(BuildKernelThunk(fusion));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
@@ -518,11 +531,34 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
- Shape input_shape = root->operand(0)->shape();
- return EmitReductionToVector(
- root, input_shape, fused_emitter.GetGenerator(root->operand(0)),
- fused_emitter.GetGenerator(root->operand(1)), root->dimensions(),
- root->to_apply());
+ // For multi-output fusion CHECK the constraints and feed all the
+ // reduces into a single loop code generator. Single-output reduce
+ // fusion is a special case of that.
+ InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
+ InlinedVector<llvm_ir::ElementGenerator, 1> init_value_gens;
+ InlinedVector<HloComputation*, 1> reducers;
+ for (const HloInstruction* reduce : reduces) {
+ CHECK_EQ(HloOpcode::kReduce, reduce->opcode());
+ // TODO(kramerb): CHECK that layouts are equal. Currently this
+ // breaks multioutputfusion_test. The test has pre-fused
+ // instructions, but layout_assignment will not assign any layouts
+ // for instructions inside of a fused computation. It just removes
+ // the layouts instead.
+ CHECK(ShapeUtil::Compatible(reduces[0]->shape(), reduce->shape()));
+ CHECK(ShapeUtil::Compatible(reduces[0]->operand(0)->shape(),
+ reduce->operand(0)->shape()));
+ CHECK(ShapeUtil::Compatible(reduces[0]->operand(1)->shape(),
+ reduce->operand(1)->shape()));
+ CHECK(reduces[0]->dimensions() == reduce->dimensions());
+ input_gens.push_back(fused_emitter.GetGenerator(reduce->operand(0)));
+ init_value_gens.push_back(
+ fused_emitter.GetGenerator(reduce->operand(1)));
+ reducers.push_back(reduce->to_apply());
+ }
+ const Shape& input_shape = reduces[0]->operand(0)->shape();
+ return EmitReductionToVector(reduces[0], input_shape, input_gens,
+ init_value_gens, reduces[0]->dimensions(),
+ reducers);
}
default:
LOG(FATAL) << "Bad opcode for input fusion: "
@@ -909,8 +945,9 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
Status IrEmitterUnnested::EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
// Number of elements processed by a single thread.
constexpr int64 kTileSize = 16;
int64 num_elems = ShapeUtil::ElementsIn(input_shape);
@@ -962,14 +999,19 @@ Status IrEmitterUnnested::EmitReductionToScalar(
//
auto loop_body_emitter =
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ const int num_reduces = reducers.size();
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
- llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
- {
- TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
- init_value_gen(llvm_ir::IrArray::Index({})));
+ std::vector<llvm::Value*> partial_reduction_result_addresses;
+ for (int i = 0; i != num_reduces; ++i) {
+ llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
+ element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](llvm_ir::IrArray::Index({})));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ partial_reduction_result_addresses.push_back(
+ partial_reduction_result_address);
}
llvm::Value* x_in_tiles = tile_index[0];
@@ -1002,11 +1044,16 @@ Status IrEmitterUnnested::EmitReductionToScalar(
llvm_ir::IrArray::Index input_index(
/*linear=*/x, input_shape, &ir_builder_);
llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
- TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
- return (EmitCallToNestedComputation(
- *reducer, {partial_reduction_result_address, input_address},
- partial_reduction_result_address));
+ for (int i = 0; i != num_reduces; ++i) {
+ TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+ input_gens[i](input_index));
+ ir_builder_.CreateStore(input_ir_value, input_address);
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {partial_reduction_result_addresses[i], input_address},
+ partial_reduction_result_addresses[i]));
+ }
+ return Status::OK();
};
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
@@ -1041,20 +1088,24 @@ Status IrEmitterUnnested::EmitReductionToScalar(
: element_ir_type;
for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
shuffle_distance /= 2) {
- llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
- ir_builder_.CreateBitCast(partial_reduction_result_address,
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
element_ir_type, nullptr, "result_from_other_lane");
- ir_builder_.CreateStore(
- EmitShuffleDown(partial_reduction_result,
- ir_builder_.getInt32(shuffle_distance), &ir_builder_),
- ir_builder_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducer, {partial_reduction_result_address, result_from_other_lane},
- partial_reduction_result_address));
+ for (int i = 0; i != num_reduces; ++i) {
+ llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
+ ir_builder_.CreateBitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
+ ir_builder_.CreateStore(
+ EmitShuffleDown(partial_reduction_result,
+ ir_builder_.getInt32(shuffle_distance),
+ &ir_builder_),
+ ir_builder_.CreateBitCast(result_from_other_lane,
+ shuffle_ir_type->getPointerTo()));
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {partial_reduction_result_addresses[i], result_from_other_lane},
+ partial_reduction_result_addresses[i]));
+ }
}
const HloInstruction* output =
@@ -1070,14 +1121,25 @@ Status IrEmitterUnnested::EmitReductionToScalar(
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
- llvm::Value* output_address =
- GetIrArray(*output, *output)
- .EmitArrayElementAddress(
- llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0),
- output->shape(), &ir_builder_),
- &ir_builder_, "output_element_address");
- return EmitAtomicOperationForNestedComputation(
- *reducer, output_address, partial_reduction_result_address);
+
+ for (int i = 0; i != num_reduces; ++i) {
+ ShapeIndex output_shape_index;
+ if (output->IsMultiOutputFusion()) {
+ output_shape_index = {i};
+ }
+ llvm::Value* output_address =
+ GetIrArray(*output, *output, output_shape_index)
+ .EmitArrayElementAddress(
+ llvm_ir::IrArray::Index(
+ /*linear=*/ir_builder_.getInt64(0),
+ ShapeUtil::GetSubshape(output->shape(),
+ output_shape_index),
+ &ir_builder_),
+ &ir_builder_, "output_element_address");
+ TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address, partial_reduction_result_addresses[i]));
+ }
+ return Status::OK();
};
// Emit a parallel loop that iterates through all input tiles, one per thread.
@@ -1097,8 +1159,9 @@ Status IrEmitterUnnested::EmitReductionToScalar(
Status IrEmitterUnnested::EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
// Divide the input matrix into tiles of size Kx1. For example, when the
// input matrix is 4x4 and K=2, the tiled matrix looks like
//
@@ -1108,9 +1171,13 @@ Status IrEmitterUnnested::EmitColumnReduction(
// 4567 // Numbers indicate tile IDs.
//
// Each tile is first partially reduced to a scalar by a thread, and then the
- // scalar is accumulated to the output vector using atomic operations. We
- // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel.
- constexpr int64 kTileSize = 16;
+ // scalar is accumulated to the output vector using atomic operations.
+ //
+ // We choose 128 as the tile size based on empirical evidence. It's big enough
+ // to reduce the amount of atomic adds in the end, maximizing the memory
+ // bandwidth.
+ constexpr int64 kTileSize = 128;
+
// If the height is not a multiple of the tile size, we pad the bottom of the
// input matrix.
const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
@@ -1140,15 +1207,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
// }
auto loop_body_emitter =
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ const int num_reduces = reducers.size();
// Emit the loop body that reduces one tile.
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
- llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
- {
- TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
- init_value_gen(llvm_ir::IrArray::Index({})));
+ std::vector<llvm::Value*> partial_reduction_result_addresses;
+ for (int i = 0; i != num_reduces; ++i) {
+ llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
+ element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](llvm_ir::IrArray::Index({})));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ partial_reduction_result_addresses.push_back(
+ partial_reduction_result_address);
}
// Emit an inner for-loop that partially reduces the elements in the given
@@ -1206,13 +1278,17 @@ Status IrEmitterUnnested::EmitColumnReduction(
.SourceIndexOfTranspose(normalized_input_shape, input_shape,
transpose_dimension_mapping,
&ir_builder_);
- TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
- input_gen(input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
+ for (int i = 0; i != num_reduces; ++i) {
+ TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+ input_gens[i](input_index));
+ ir_builder_.CreateStore(input_ir_value, input_address);
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {partial_reduction_result_addresses[i], input_address},
+ partial_reduction_result_addresses[i]));
+ }
+ return Status::OK();
}
- return (EmitCallToNestedComputation(
- *reducer, {partial_reduction_result_address, input_address},
- partial_reduction_result_address));
};
// y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
@@ -1241,13 +1317,24 @@ Status IrEmitterUnnested::EmitColumnReduction(
&ir_builder_);
const HloInstruction* output =
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
- llvm::Value* output_address =
- GetIrArray(*output, *output)
- .EmitArrayElementAddress(
- llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_),
- &ir_builder_, "output_element_address");
- return EmitAtomicOperationForNestedComputation(
- *reducer, output_address, partial_reduction_result_address);
+ for (int i = 0; i != num_reduces; ++i) {
+ ShapeIndex output_shape_index;
+ if (output->IsMultiOutputFusion()) {
+ output_shape_index = {i};
+ }
+ llvm::Value* output_address =
+ GetIrArray(*output, *output, output_shape_index)
+ .EmitArrayElementAddress(
+ llvm_ir::IrArray::Index(
+ x,
+ ShapeUtil::GetSubshape(output->shape(),
+ output_shape_index),
+ &ir_builder_),
+ &ir_builder_, "output_element_address");
+ TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address, partial_reduction_result_addresses[i]));
+ }
+ return Status::OK();
};
// Emit a parallel loop that iterate through all input tiles.
@@ -1267,8 +1354,10 @@ Status IrEmitterUnnested::EmitColumnReduction(
Status IrEmitterUnnested::EmitRowReduction(
int64 depth, int64 height, int64 width, HloInstruction* reduce,
- const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+ const Shape& input_shape,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
// A naive algorithm is:
// 1. Divide the input tensor into tiles of size 1x1xK.
// 2. Partially reduces each tile to a scalar using one thread.
@@ -1358,15 +1447,20 @@ Status IrEmitterUnnested::EmitRowReduction(
auto loop_body_emitter =
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ const int num_reduces = reducers.size();
// Emit the loop body that reduces one tile.
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
input_shape.element_type(), ir_emitter_context_->llvm_module());
- llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
- element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
- {
- TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
- init_value_gen(llvm_ir::IrArray::Index({})));
+ std::vector<llvm::Value*> partial_reduction_result_addresses;
+ for (int i = 0; i != num_reduces; ++i) {
+ llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
+ element_ir_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(i));
+ TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+ init_value_gens[i](llvm_ir::IrArray::Index({})));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ partial_reduction_result_addresses.push_back(
+ partial_reduction_result_address);
}
// Emit an inner for-loop that partially reduces the elements in the given
@@ -1449,13 +1543,17 @@ Status IrEmitterUnnested::EmitRowReduction(
.SourceIndexOfTranspose(normalized_input_shape, input_shape,
transpose_dimension_mapping,
&ir_builder_);
- TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
- input_gen(input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
+ for (int i = 0; i != num_reduces; ++i) {
+ TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+ input_gens[i](input_index));
+ ir_builder_.CreateStore(input_ir_value, input_address);
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {partial_reduction_result_addresses[i], input_address},
+ partial_reduction_result_addresses[i]));
+ }
+ return Status::OK();
}
- return EmitCallToNestedComputation(
- *reducer, {partial_reduction_result_address, input_address},
- partial_reduction_result_address);
};
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
@@ -1483,20 +1581,24 @@ Status IrEmitterUnnested::EmitRowReduction(
: element_ir_type;
for (int shuffle_distance = 16; shuffle_distance >= 1;
shuffle_distance /= 2) {
- llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
- ir_builder_.CreateBitCast(partial_reduction_result_address,
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
element_ir_type, nullptr, "result_from_other_lane");
- ir_builder_.CreateStore(
- EmitShuffleDown(partial_reduction_result,
- ir_builder_.getInt32(shuffle_distance), &ir_builder_),
- ir_builder_.CreateBitCast(result_from_other_lane,
- shuffle_ir_type->getPointerTo()));
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducer, {partial_reduction_result_address, result_from_other_lane},
- partial_reduction_result_address));
+ for (int i = 0; i != num_reduces; ++i) {
+ llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
+ ir_builder_.CreateBitCast(partial_reduction_result_addresses[i],
+ shuffle_ir_type->getPointerTo()),
+ "partial_reduction_result");
+ ir_builder_.CreateStore(
+ EmitShuffleDown(partial_reduction_result,
+ ir_builder_.getInt32(shuffle_distance),
+ &ir_builder_),
+ ir_builder_.CreateBitCast(result_from_other_lane,
+ shuffle_ir_type->getPointerTo()));
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {partial_reduction_result_addresses[i], result_from_other_lane},
+ partial_reduction_result_addresses[i]));
+ }
}
const HloInstruction* output =
@@ -1510,13 +1612,24 @@ Status IrEmitterUnnested::EmitRowReduction(
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
- llvm::Value* output_address =
- GetIrArray(*output, *output)
- .EmitArrayElementAddress(
- llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_),
- &ir_builder_, "output_element_address");
- return EmitAtomicOperationForNestedComputation(
- *reducer, output_address, partial_reduction_result_address);
+ for (int i = 0; i != num_reduces; ++i) {
+ ShapeIndex output_shape_index;
+ if (output->IsMultiOutputFusion()) {
+ output_shape_index = {i};
+ }
+ llvm::Value* output_address =
+ GetIrArray(*output, *output, output_shape_index)
+ .EmitArrayElementAddress(
+ llvm_ir::IrArray::Index(
+ y,
+ ShapeUtil::GetSubshape(output->shape(),
+ output_shape_index),
+ &ir_builder_),
+ &ir_builder_, "output_element_address");
+ TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address, partial_reduction_result_addresses[i]));
+ }
+ return Status::OK();
};
// Emit a parallel loop that iterates through every input tiles.
@@ -1543,10 +1656,10 @@ Status IrEmitterUnnested::EmitRowReduction(
// elementwise.
Status IrEmitterUnnested::EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- HloComputation* reducer) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
// This emission requires "reduce" to have an input layout. It is either set
// by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
// a fused kReduce).
@@ -1581,8 +1694,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
// `EmitReductionToVector`, we only need to check whether the minormost
// dimension of the input is to keep.
if (input_dims_to_keep.empty()) {
- return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen,
- reducer);
+ return EmitReductionToScalar(reduce, input_shape, input_gens,
+ init_value_gens, reducers);
} else if (input_dims_to_keep.front() ==
LayoutUtil::Minor(input_shape.layout(), 0)) {
// Column reduction. Treat the result of "input" as a matrix whose width
@@ -1599,8 +1712,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
height *= input_shape.dimensions(input_dim);
}
}
- return EmitColumnReduction(height, width, reduce, input_shape, input_gen,
- init_value_gen, reducer);
+ return EmitColumnReduction(height, width, reduce, input_shape, input_gens,
+ init_value_gens, reducers);
} else {
// Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a
// 3D tensor. The size of dimension 1 (the height) is the size of the
@@ -1626,7 +1739,7 @@ Status IrEmitterUnnested::EmitReductionToVector(
}
const int64 height = ShapeUtil::ElementsIn(reduce->shape());
return EmitRowReduction(depth, height, width, reduce, input_shape,
- input_gen, init_value_gen, reducer);
+ input_gens, init_value_gens, reducers);
}
}
@@ -1650,16 +1763,15 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
MakeUnique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
- reduce, input->shape(),
- [&](const llvm_ir::IrArray::Index& index) {
+ reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*input, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
- },
- [&](const llvm_ir::IrArray::Index& index) {
+ }},
+ {[&](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*init_value, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
- },
- dimensions_to_reduce, reducer);
+ }},
+ dimensions_to_reduce, {reducer});
}
thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
@@ -2324,7 +2436,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
- const HloInstruction* hlo) {
+ const HloInstruction* hlo, const ShapeIndex& index) {
bool fused = HloOpcode::kFusion == hlo->opcode();
const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
const HloInstruction* init_value = [&] {
@@ -2333,6 +2445,11 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
return inst->operand(2);
case HloOpcode::kReduce:
return inst->operand(1);
+ case HloOpcode::kTuple:
+ CHECK(hlo->IsMultiOutputFusion() &&
+ inst->operand(index.back())->opcode() == HloOpcode::kReduce);
+ // For multi-output fusion look through the tuple.
+ return inst->operand(index.back())->operand(1);
default:
LOG(FATAL) << "Opcode " << inst->opcode()
<< " should not need an initializer.";
@@ -2356,7 +2473,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ArraySlice<uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
- return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo), hlo)};
+ return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), hlo)};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
@@ -2372,8 +2489,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
pattern16 = literal_bytes.front();
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
- return {MakeUnique<Memset32BitValueThunk>(pattern32,
- GetAllocationSlice(*hlo), hlo)};
+ return {MakeUnique<Memset32BitValueThunk>(
+ pattern32, GetAllocationSlice(*hlo, index), hlo)};
}
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
@@ -2383,8 +2500,8 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
literal_bytes.size() - 4) == 0) {
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
- return {MakeUnique<Memset32BitValueThunk>(word, GetAllocationSlice(*hlo),
- hlo)};
+ return {MakeUnique<Memset32BitValueThunk>(
+ word, GetAllocationSlice(*hlo, index), hlo)};
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index e42c5e8686..b41eaa303b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -110,28 +110,31 @@ class IrEmitterUnnested : public IrEmitter {
// `EmitReductionToVector`. Note that input shape might not be
// [height x width], but can be bitcast to [height x weight] with "height"
// being the major dimension.
- Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
- const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
- HloComputation* reducer);
+ Status EmitColumnReduction(
+ int64 height, int64 width, HloInstruction* reduce,
+ const Shape& input_shape,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers);
// Emits code that reduces a 3D tensor of shape [depth x height x width] to a
// vector of shape [height]. Other parameters have the same meaning as those
// of `EmitReductionToVector`. Note that input shape might not be
// [depth x height x width], but can be bitcast to [depth x height x weight]
// with "depth" being the most major dimension.
- Status EmitRowReduction(int64 depth, int64 height, int64 width,
- HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
- HloComputation* reducer);
+ Status EmitRowReduction(
+ int64 depth, int64 height, int64 width, HloInstruction* reduce,
+ const Shape& input_shape,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers);
// Emits code that reduces a tensor of arbitrary rank to a scalar.
- Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
- HloComputation* reducer);
+ Status EmitReductionToScalar(
+ HloInstruction* reduce, const Shape& input_shape,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers);
// Figures out whether `reduce` is a row or column reduction, and which
// dimensions to reduce, and calls either `EmitRowReduction` or
@@ -141,13 +144,16 @@ class IrEmitterUnnested : public IrEmitter {
// generate elements of the input and the initial value. Other parameters mean
// the same as for `HandleReduce`.
//
+ // Multiple reduces can be emitted in the same loop, assuming they have the
+ // same input and output shapes, and the same reduce dimensions.
+ //
// Prerequisite: `IsReductionToVector(*reduce)`
Status EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
- const llvm_ir::ElementGenerator& input_gen,
- const llvm_ir::ElementGenerator& init_value_gen,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
+ tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- HloComputation* reducer);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers);
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
@@ -166,7 +172,7 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a thunk that, given a reduce or select-and-scatter op, initializes
// its memory to the appropriate initial value.
StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
- const HloInstruction* hlo);
+ const HloInstruction* hlo, const ShapeIndex& index = {});
// Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
new file mode 100644
index 0000000000..a50ddf6ac6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+
+namespace xla {
+namespace gpu {
+
+using stream_executor::dnn::DataLayout;
+using stream_executor::dnn::DataLayoutString;
+using stream_executor::dnn::FilterLayout;
+using stream_executor::dnn::FilterLayoutString;
+
+StatusOr<std::tuple<Layout, Layout, Layout>>
+StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
+ DataLayout input, FilterLayout filter,
+ DataLayout output) {
+ std::vector<int64> input_layout;
+ switch (input) {
+ case DataLayout::kBatchDepthYX:
+ input_layout.push_back(dnums.input_batch_dimension());
+ input_layout.push_back(dnums.input_feature_dimension());
+ input_layout.insert(input_layout.end(),
+ dnums.input_spatial_dimensions().begin(),
+ dnums.input_spatial_dimensions().end());
+ break;
+ case DataLayout::kBatchYXDepth:
+ input_layout.push_back(dnums.input_batch_dimension());
+ input_layout.insert(input_layout.end(),
+ dnums.input_spatial_dimensions().begin(),
+ dnums.input_spatial_dimensions().end());
+ input_layout.push_back(dnums.input_feature_dimension());
+ break;
+ default:
+ return tensorflow::errors::Internal("Invalid input layout: ",
+ DataLayoutString(input));
+ }
+
+ std::vector<int64> filter_layout;
+ switch (filter) {
+ case FilterLayout::kOutputInputYX:
+ filter_layout.push_back(dnums.kernel_output_feature_dimension());
+ filter_layout.push_back(dnums.kernel_input_feature_dimension());
+ filter_layout.insert(filter_layout.end(),
+ dnums.kernel_spatial_dimensions().begin(),
+ dnums.kernel_spatial_dimensions().end());
+ break;
+ case FilterLayout::kOutputYXInput:
+ filter_layout.push_back(dnums.kernel_output_feature_dimension());
+ filter_layout.insert(filter_layout.end(),
+ dnums.kernel_spatial_dimensions().begin(),
+ dnums.kernel_spatial_dimensions().end());
+ filter_layout.push_back(dnums.kernel_input_feature_dimension());
+ break;
+ default:
+ return tensorflow::errors::Internal("Invalid filter layout: ",
+ FilterLayoutString(filter));
+ }
+
+ std::vector<int64> output_layout;
+ switch (output) {
+ case DataLayout::kBatchDepthYX:
+ output_layout.push_back(dnums.output_batch_dimension());
+ output_layout.push_back(dnums.output_feature_dimension());
+ output_layout.insert(output_layout.end(),
+ dnums.output_spatial_dimensions().begin(),
+ dnums.output_spatial_dimensions().end());
+ break;
+ case DataLayout::kBatchYXDepth:
+ output_layout.push_back(dnums.output_batch_dimension());
+ output_layout.insert(output_layout.end(),
+ dnums.output_spatial_dimensions().begin(),
+ dnums.output_spatial_dimensions().end());
+ output_layout.push_back(dnums.output_feature_dimension());
+ break;
+ default:
+ return tensorflow::errors::Internal("Invalid output layout: ",
+ DataLayoutString(output));
+ }
+
+ return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
+ LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
+ LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
+}
+
+StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
+XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
+ const Layout& input, const Layout& filter,
+ const Layout& output) {
+ Layout nchw_input, nchw_filter, nchw_output;
+ std::tie(nchw_input, nchw_filter, nchw_output) =
+ StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
+ FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX)
+ .ConsumeValueOrDie();
+
+ Layout nhwc_input, nhwc_filter, nhwc_output;
+ std::tie(nhwc_input, nhwc_filter, nhwc_output) =
+ StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
+ FilterLayout::kOutputYXInput,
+ DataLayout::kBatchYXDepth)
+ .ConsumeValueOrDie();
+
+ DataLayout input_layout;
+ if (LayoutUtil::Equal(input, nchw_input)) {
+ input_layout = DataLayout::kBatchDepthYX;
+ } else if (LayoutUtil::Equal(input, nhwc_input)) {
+ input_layout = DataLayout::kBatchYXDepth;
+ } else {
+ return tensorflow::errors::Internal("Invalid input layout: ",
+ input.ShortDebugString());
+ }
+
+ FilterLayout filter_layout;
+ if (LayoutUtil::Equal(filter, nchw_filter)) {
+ filter_layout = FilterLayout::kOutputInputYX;
+ } else if (LayoutUtil::Equal(filter, nhwc_filter)) {
+ filter_layout = FilterLayout::kOutputYXInput;
+ } else {
+ return tensorflow::errors::Internal("Invalid filter layout: ",
+ filter.ShortDebugString());
+ }
+
+ DataLayout output_layout;
+ if (LayoutUtil::Equal(output, nchw_output)) {
+ output_layout = DataLayout::kBatchDepthYX;
+ } else if (LayoutUtil::Equal(output, nhwc_output)) {
+ output_layout = DataLayout::kBatchYXDepth;
+ } else {
+ return tensorflow::errors::Internal("Invalid output layout: ",
+ output.ShortDebugString());
+ }
+
+ return std::make_tuple(input_layout, filter_layout, output_layout);
+}
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
new file mode 100644
index 0000000000..8218f4fd11
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
+
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+// Helper functions for interacting with StreamExecutor.
+
+namespace xla {
+namespace gpu {
+
+// Returns (input, filter, output) XLA Layout protos given the StreamExecutor
+// layouts.
+StatusOr<std::tuple<Layout, Layout, Layout>>
+StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
+ stream_executor::dnn::DataLayout input,
+ stream_executor::dnn::FilterLayout filter,
+ stream_executor::dnn::DataLayout output);
+
+// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts.
+StatusOr<std::tuple<stream_executor::dnn::DataLayout,
+ stream_executor::dnn::FilterLayout,
+ stream_executor::dnn::DataLayout>>
+XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
+ const Layout& input, const Layout& filter,
+ const Layout& output);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index b06e6c9f3e..cc130a4900 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -363,7 +363,7 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
bool HloDataflowAnalysis::UpdateConditionalValueSet(
HloInstruction* conditional) {
CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
- std::vector<const InstructionValueSet*> inputs = {
+ const InstructionValueSet* const inputs[] = {
&GetInstructionValueSet(
conditional->true_computation()->root_instruction()),
&GetInstructionValueSet(
@@ -538,7 +538,7 @@ bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
- std::vector<const InstructionValueSet*> inputs = {
+ const InstructionValueSet* const inputs[] = {
&GetInstructionValueSet(xla_while->while_body()->root_instruction()),
&GetInstructionValueSet(xla_while->operand(0))};
if (ssa_form_) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index fa59a5fb20..e90eb0669d 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -309,6 +309,35 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
return result;
}
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
+ HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
+ std::unique_ptr<HloInstruction> lhs_instr =
+ HloInstruction::CreateConstant(lhs.CloneToUnique());
+ std::unique_ptr<HloInstruction> rhs_instr =
+ HloInstruction::CreateConstant(rhs.CloneToUnique());
+
+ std::unique_ptr<HloInstruction> cloned_instruction =
+ HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
+ rhs_instr.get());
+ auto result = Evaluate(cloned_instruction.get());
+
+ cloned_instruction->DetachFromOperands();
+ return result;
+}
+
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
+ HloOpcode opcode, const Literal& operand) {
+ std::unique_ptr<HloInstruction> operand_instr =
+ HloInstruction::CreateConstant(operand.CloneToUnique());
+
+ std::unique_ptr<HloInstruction> cloned_instruction =
+ HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
+ auto result = Evaluate(cloned_instruction.get());
+
+ cloned_instruction->DetachFromOperands();
+ return result;
+}
+
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
CHECK_LT(parameter->parameter_number(), arg_literals_.size());
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
@@ -859,6 +888,28 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
return Status::OK();
}
+Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
+ const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0));
+
+ TF_RET_CHECK(broadcast->dimensions().size() ==
+ ShapeUtil::Rank(operand.shape()))
+ << "broadcast dimensions is of size: " << broadcast->dimensions().size()
+ << " and rank of operand_to_broadcast is: "
+ << ShapeUtil::Rank(operand.shape());
+ // Checks that operand's dimensions are the same as the broadcast's
+ // dimensions along the dimensions to be broadcasted.
+ for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
+ TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
+ operand.shape().dimensions(i));
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ evaluated_[broadcast],
+ operand.Broadcast(broadcast->shape(), broadcast->dimensions()));
+
+ return Status::OK();
+}
+
Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const auto result_shape = get_tuple_element->shape();
const int64 index = get_tuple_element->tuple_index();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 566d53a414..b53d5644de 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -109,6 +109,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
const std::unordered_map<const HloInstruction*, const Literal*>&
substitutions);
+ StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
+ HloOpcode opcode, const Literal& lhs, const Literal& rhs);
+
+ StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
+ HloOpcode opcode, const Literal& operand);
+
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
// class.
@@ -166,6 +172,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override;
+ Status HandleBroadcast(HloInstruction* broadcast) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index ae5b5e0412..84b4ead2dd 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -262,13 +262,13 @@ TEST_P(HloEvaluatorTest, DoesCosR2) {
auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
auto expected = Literal::CreateR2<float>({{1, -1}, {-1, 1}});
TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand),
- use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20);
+ use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
}
TEST_P(HloEvaluatorTest, DoesSinR2) {
auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
auto expected = Literal::CreateR2<float>({{0, 0}, {0, 0}});
TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand),
- use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20);
+ use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
}
TEST_P(HloEvaluatorTest, DoesNotR2) {
auto operand =
@@ -333,7 +333,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
result->EachCell<NativeT>(
[&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
- EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0x1.0P-5);
+ EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
});
}
@@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f;
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
- EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5)));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
}
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 024e8751f7..82ee77e1ae 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -161,36 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleRound<ReturnT>(round);
}
- Status HandleBroadcast(HloInstruction* broadcast) override {
- const Literal& operand_to_broadcast =
- parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
- std::vector<int64> broadcast_indices(
- ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
-
- TF_RET_CHECK(broadcast->dimensions().size() ==
- ShapeUtil::Rank(operand_to_broadcast.shape()))
- << "broadcast dimensions is of size: " << broadcast->dimensions().size()
- << " and rank of operand_to_broadcast is: "
- << ShapeUtil::Rank(operand_to_broadcast.shape());
- // Checks that operand's dimensions are the same as the broadcast's
- // dimensions along the dimensions to be broadcasted.
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
- operand_to_broadcast.shape().dimensions(i));
- }
-
- auto output = MakeUnique<Literal>(broadcast->shape());
- TF_RETURN_IF_ERROR(output->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
- }
- return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
- }));
- parent_->evaluated_[broadcast] = std::move(output);
- return Status::OK();
- }
-
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
@@ -1482,11 +1452,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Evaluate computation with specified literal operands.
auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
- std::vector<const Literal*> args = {result_val_literal.get(),
- curr_val_literal.get()};
std::unique_ptr<Literal> computed_result =
- embedded_evaluator.Evaluate<const Literal*>(*function, args)
+ embedded_evaluator
+ .Evaluate<const Literal*>(
+ *function,
+ {result_val_literal.get(), curr_val_literal.get()})
.ConsumeValueOrDie();
// Clear visit states so that we can use the evaluator again on
// the same computation.
@@ -1685,10 +1656,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Literal::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
Literal::CreateR0<ReturnT>(result_val);
- const std::vector<const Literal*> args = {
- result_val_literal.get(), curr_val_literal.get()};
std::unique_ptr<Literal> computed_result =
- embedded_evaluator.Evaluate<const Literal*>(*function, args)
+ embedded_evaluator
+ .Evaluate<const Literal*>(
+ *function,
+ {result_val_literal.get(), curr_val_literal.get()})
.ConsumeValueOrDie();
// Clear visit states so that the we can use the evaluate again
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 17e3c405f1..a2cb21c09b 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -321,13 +321,11 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
- const DebugOptions& debug_options, bool show_metadata,
- bool show_backend_config, const HloExecutionProfile* profile,
- NodeFilter filter)
+ const DebugOptions& debug_options, bool show_backend_config,
+ const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
label_(std::string(label)),
debug_options_(debug_options),
- show_metadata_(show_metadata),
show_backend_config_(show_backend_config),
profile_(profile),
filter_(std::move(filter)) {}
@@ -395,7 +393,6 @@ class HloDotDumper {
const HloComputation* computation_; // never null
const string label_; // overall name for the graph
const DebugOptions& debug_options_;
- const bool show_metadata_;
const bool show_backend_config_;
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
@@ -791,16 +788,16 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
}
// Build the text that will be displayed inside the node.
string node_body = node_label;
- for (const string& s : {trivial_subcomputation, node_metadata,
- node_backend_config, extra_info, inlined_constants}) {
+ for (const string& s : {trivial_subcomputation, node_backend_config,
+ extra_info, inlined_constants}) {
if (!s.empty()) {
StrAppend(&node_body, "<br/>", s);
}
}
- return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
+ return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
"\n",
- InstructionId(instr), node_body, node_shape,
+ InstructionId(instr), node_body, node_shape, node_metadata,
NodeColorAttributes(color));
}
@@ -1068,10 +1065,6 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
}
string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
- if (!show_metadata_) {
- return "";
- }
-
std::vector<string> lines;
if (!instr->metadata().op_name().empty()) {
lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
@@ -1154,6 +1147,20 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
return Join(lines, "<br/>");
}
+// Gets the total number of array elements in the given shape. For tuples, this
+// is the sum of all the sizes of all of the array elements recursively in the
+// tuple.
+static int64 TotalElementsInShape(const Shape& shape) {
+ int64 elems = 0;
+ ShapeUtil::ForEachSubshape(
+ shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) {
+ if (ShapeUtil::IsArray(subshape)) {
+ elems += ShapeUtil::ElementsIn(subshape);
+ }
+ });
+ return elems;
+}
+
void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
int64 operand_num, bool control_edge = false) {
@@ -1173,9 +1180,16 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
} else if (control_edge) {
edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
}
- const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)";
+
+ // We print "small" arrays using a hollow arrowhead and "large" arrays using
+ // a filled arrowhead. For now, we use an arbitrary cutoff for what "big"
+ // means.
+ bool is_big_array = TotalElementsInShape(from->shape()) >= 4096;
+
+ const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to),
- from->name(), to->name(), edge_label));
+ (is_big_array ? "normal" : "empty"), from->name(),
+ to->name(), edge_label));
};
// Add edges from instr's operands to instr. Parameters within fusion
@@ -1425,7 +1439,7 @@ string ExportGraph(const string& graph,
string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile,
- bool show_metadata, bool show_backend_config) {
+ bool show_backend_config) {
GraphRendererInterface::GraphKind graph_kind;
string graph;
if (debug_options.xla_hlo_dump_as_graphdef()) {
@@ -1436,8 +1450,8 @@ string DumpGraph(const HloComputation& computation, const string& label,
graph_kind = GraphRendererInterface::TF_GRAPHDEF;
} else {
graph =
- HloDotDumper(&computation, label, debug_options, show_metadata,
- show_backend_config, hlo_execution_profile, NodeFilter())
+ HloDotDumper(&computation, label, debug_options, show_backend_config,
+ hlo_execution_profile, NodeFilter())
.Dump();
graph_kind = GraphRendererInterface::DOT_GRAPH;
}
@@ -1449,15 +1463,15 @@ string DumpGraph(const HloComputation& computation, const string& label,
}
string DumpNeighborhoodAround(const HloInstruction& node, int radius,
- bool show_metadata, bool show_backend_config) {
+ bool show_backend_config) {
auto debug_options = node.GetModule()->config().debug_options();
string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
NodeFilter filter = MakeNodeFilter(&node, radius);
- string graph = HloDotDumper(node.parent(), label, debug_options,
- show_metadata, show_backend_config,
- /*profile=*/nullptr, filter)
- .Dump();
+ string graph =
+ HloDotDumper(node.parent(), label, debug_options, show_backend_config,
+ /*profile=*/nullptr, filter)
+ .Dump();
return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
index fc8e1468ac..0b11f34abb 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
@@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label,
string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile = nullptr,
- bool show_metadata = false, bool show_backend_config = false);
+ bool show_backend_config = false);
// Like DumpGraph, but renders only nodes "near" the given node in the graph.
//
@@ -64,7 +64,6 @@ string DumpGraph(const HloComputation& computation, const string& label,
// (roughly) corresponds to the max distance a node may be from the primary node
// before it's omitted from the graph.
string DumpNeighborhoodAround(const HloInstruction& node, int radius,
- bool show_metadata = false,
bool show_backend_config = false);
// Dumps the HloModule::ToString() as a file into the provided directory path
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index c33bdadf1c..dfefad3634 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -324,6 +325,12 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
return ::testing::MakeMatcher(
new ::xla::testing::HloShardingMatcher(sharding));
}
+// Matcher for Sharding from sharding string
+inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
+ tensorflow::StringPiece sharding) {
+ return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
+ xla::tools::ParseSharding(sharding).ValueOrDie()));
+}
// Verifies that no HloSharding is set for an HLO instruction.
inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
return ::testing::MakeMatcher(
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 016cc01e33..1d10e3c4fe 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
@@ -147,6 +146,18 @@ TEST(HloMatchersTest, ShardingMatcher) {
"param.1");
p1->set_sharding(HloSharding::AssignDevice(1));
+ auto tuple_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}),
+ ShapeUtil::MakeShape(F32, {11})});
+ auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2");
+ Array<int64> assignment({2});
+ assignment.SetValues({0, 1});
+ auto sharding = HloSharding::Tuple(
+ tuple_shape,
+ {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment),
+ HloSharding::AssignDevice(1), HloSharding::Replicate()});
+ p2->set_sharding(sharding);
+
EXPECT_THAT(p0.get(), op::NoSharding());
EXPECT_THAT(p0.get(),
::testing::Not(op::Sharding(HloSharding::AssignDevice(1))));
@@ -155,6 +166,11 @@ TEST(HloMatchersTest, ShardingMatcher) {
::testing::Not(op::Sharding(HloSharding::AssignDevice(0))));
EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1)));
+ EXPECT_THAT(
+ p2.get(),
+ op::Sharding(
+ "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}"));
+
EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
"%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
"{maximal device=1})");
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc
index 8e167633bb..4738e46f8a 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability.cc
@@ -33,17 +33,27 @@ bool HloReachabilityMap::SetReachabilityToUnion(
const HloInstruction* instruction) {
BitVector& bit_vector = GetBitVector(instruction);
tmp_bit_vector_ = bit_vector;
+ SetReachabilityToUnionHelper(inputs, instruction, &bit_vector);
+ return bit_vector != tmp_bit_vector_;
+}
+void HloReachabilityMap::FastSetReachabilityToUnion(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ const HloInstruction* instruction) {
+ SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction));
+}
+
+void HloReachabilityMap::SetReachabilityToUnionHelper(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ const HloInstruction* instruction, BitVector* bit_vector) {
// If instruction is part of inputs, don't reset the bit_vector.
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
- bit_vector.SetToZero();
+ bit_vector->SetToZero();
}
- bit_vector.Set(GetIndex(instruction));
+ bit_vector->Set(GetIndex(instruction));
for (const HloInstruction* input : inputs) {
- bit_vector.OrWith(GetBitVector(input));
+ bit_vector->OrWith(GetBitVector(input));
}
-
- return bit_vector != tmp_bit_vector_;
}
void HloReachabilityMap::SetReachable(const HloInstruction* a,
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index 553ec11f6f..69bb2b3cee 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -57,6 +57,11 @@ class HloReachabilityMap {
tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
const HloInstruction* instruction);
+ // As above, but faster because it does not check if the reachability changed.
+ void FastSetReachabilityToUnion(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ const HloInstruction* instruction);
+
// Sets entry so that IsReachable(a, b) will return true
//
// !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
@@ -133,6 +138,11 @@ class HloReachabilityMap {
return bit_vectors_[GetIndex(instruction)];
}
+ // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
+ void SetReachabilityToUnionHelper(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
+ const HloInstruction* instruction, BitVector* bit_vector);
+
// Return the index of the given instruction. The value is used to index into
// the vector of BitVectors and the BitVectors themselves.
int GetIndex(const HloInstruction* instruction) const {
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 2a601ec3d1..7127adf456 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -94,8 +94,8 @@ HloRunner::~HloRunner() {}
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
std::unique_ptr<HloModule> module,
- const tensorflow::gtl::ArraySlice<Literal*> arguments,
- bool run_hlo_passes) {
+ const tensorflow::gtl::ArraySlice<Literal*> arguments, bool run_hlo_passes,
+ ExecutionProfile* profile) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
CreateExecutable(std::move(module), run_hlo_passes));
se::Stream stream(backend().default_stream_executor());
@@ -127,7 +127,7 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer result,
executable->ExecuteOnStreamWrapper(
- &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs));
+ &service_run_options, /*profile=*/profile, argument_buffer_ptrs));
auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice(
stream.parent(), result);
@@ -141,6 +141,18 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
return result_literal;
}
+StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
+ std::unique_ptr<HloModule> module,
+ const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
+ bool run_hlo_passes, ExecutionProfile* profile) {
+ // Construct a vector of plain pointers for the arguments.
+ std::vector<Literal*> argument_pointers;
+ c_transform(
+ arguments, std::back_inserter(argument_pointers),
+ [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+ return Execute(std::move(module), argument_pointers, run_hlo_passes, profile);
+}
+
StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const ReplicatedExecuteOptions& options) {
@@ -295,4 +307,8 @@ Backend& HloRunner::backend() {
return *backend_;
}
+const Backend& HloRunner::backend() const {
+ return const_cast<HloRunner*>(this)->backend();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 53f7c6fe4a..aa62659ac3 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -110,19 +110,12 @@ class HloRunner {
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<Literal*> arguments,
- bool run_hlo_passes = true);
+ bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
StatusOr<std::unique_ptr<Literal>> Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>> arguments,
- bool run_hlo_passes = true) {
- // Construct a vector of plain pointers for the arguments.
- std::vector<Literal*> argument_pointers;
- c_transform(
- arguments, std::back_inserter(argument_pointers),
- [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
- return Execute(std::move(module), argument_pointers, run_hlo_passes);
- }
+ bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
// Executes a given HLO module into a set of replicas, and returns a map
// with the replica number as key, and the corresponding returned literal as
@@ -137,6 +130,7 @@ class HloRunner {
// This creates the backend lazily so it's possible to instantiate an
// HloRunner in a program without any backends linked in.
Backend& backend();
+ const Backend& backend() const;
private:
// Creates an executable object given an HLO module. If run_hlo_passes is
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 854aa94319..68b2cde83a 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -299,6 +299,8 @@ class ListScheduler {
auto best_it = ready_queue.end();
--best_it;
const HloInstruction* best = best_it->second.instruction;
+ VLOG(2) << "Schedule instruction: " << best->ToShortString()
+ << " Bytes freed: " << best_it->first.first;
ready_queue.erase(best_it);
ready_instructions.erase(best);
schedule.push_back(best);
@@ -437,6 +439,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
// simply users-1 for each instruction. By subtracting 1, we're saying that
// instructions with no users or a single user don't count; instructions with
// lots of fan-out will be visited earlier.
+ int64 cumulative_total_size = 0;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
@@ -449,12 +452,21 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
int64 logical_buffer_size = SumLogicalBufferSizes(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
total_sizes[hlo] = logical_buffer_size;
+ cumulative_total_size += logical_buffer_size;
tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
hlo->operands().begin(), hlo->operands().end());
for (const HloInstruction* operand : unique_operands) {
extra_users[hlo] += extra_users[operand];
total_sizes[hlo] += total_sizes[operand];
}
+ // total_sizes[hlo] transitively includes the sizes of all nodes that
+ // lead to it. But computation is a DAG, so we are double-counting nodes,
+ // which can lead to overflows for large programs.
+ // cumulative_total_size caps the size to prevent overflows.
+ // NOTE(dimvar): this is quite ugly and should be changed. It's unclear
+ // why we care about transitive sizes; when scheduling a node, its input
+ // and output buffers should be all that matters, not its "history".
+ total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
}
CHECK_EQ(extra_users.size(), computation.instruction_count());
CHECK_EQ(total_sizes.size(), computation.instruction_count());
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index c018ba2ffc..0bc930f9ea 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -289,5 +289,100 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
}
+TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
+ auto builder = HloComputation::Builder(TestName());
+ const auto TUPLE_SIZE = 1;
+ const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6});
+
+ // Wrap lit in abs because constants are considered free by
+ // IgnoreInstruction, and it skews the accounting.
+ auto lit = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({1, 1, 1, 1, 1, 1})));
+ auto abs_const = builder.AddInstruction(
+ HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit));
+
+ auto abs_abs1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
+ tensorflow::gtl::ArraySlice<HloInstruction*>({abs_abs1})));
+ auto tuple_elm = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
+
+ auto abs_abs2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
+
+ builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd,
+ tuple_elm, abs_abs2));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ CreateMemoryMinimizingSequence(*module,
+ [&TUPLE_SIZE](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(
+ buffer.shape(), TUPLE_SIZE);
+ },
+ ListMemoryScheduler));
+
+ // Verify that all instructions are in the sequence.
+ EXPECT_EQ(module->entry_computation()->instruction_count(),
+ sequence.at(module->entry_computation()).size());
+ SequentialHloOrdering ordering(module.get(), sequence);
+ // tuple allocates the tuple buffer and doesn't free anything.
+ // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
+ // abs_abs2 should be scheduled before tuple by List.
+ EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple));
+}
+
+TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
+ const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5});
+ HloComputation::Builder builder(TestName());
+
+ auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({1, 1, 1, 1, 1})));
+ auto c2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({1, 2, 3, 4, 5})));
+ auto c3 = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({0, 2, 4, 6, 8})));
+
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2));
+ auto mul = builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul}));
+
+ auto tuple_elm = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
+
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3));
+
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp));
+
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+
+ auto fusion = computation->CreateFusionInstruction(
+ {tuple, mul, add}, HloInstruction::FusionKind::kLoop);
+
+ TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
+ CreateMemoryMinimizingSequence(
+ *module,
+ [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
+ },
+ ListMemoryScheduler));
+
+ // Verify that all instructions are in the sequence.
+ EXPECT_EQ(module->entry_computation()->instruction_count(),
+ sequence.at(module->entry_computation()).size());
+ SequentialHloOrdering ordering(module.get(), sequence);
+ // fusion allocates memory for the tuple elements and doesn't free anything,
+ // so it's more expensive than exp.
+ EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 3bf0d25efb..94d1a3226b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_sharding.h"
-
#include <set>
#include <unordered_map>
#include <utility>
@@ -25,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -312,5 +311,48 @@ TEST_F(HloShardingTest, OstreamTest) {
EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}");
}
+TEST_F(HloShardingTest, Parse) {
+ auto check = [](const HloSharding& sharding) {
+ TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding,
+ tools::ParseSharding(sharding.ToString()));
+ EXPECT_EQ(sharding, parsed_sharding);
+ };
+ check(HloSharding::Replicate());
+ check(HloSharding::AssignDevice(2));
+ check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
+ Array4D<int64>({{{{0}, {1}}}})));
+ // Empty tuple.
+ check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), {}));
+ {
+ // Non-nested tuple.
+ auto tuple_shape =
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
+ ShapeUtil::MakeShape(F32, {3, 5, 7}),
+ ShapeUtil::MakeShape(F32, {3, 7})});
+ check(HloSharding::Tuple(
+ tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
+ Array4D<int64>({{{{0}, {1}}}})),
+ HloSharding::Replicate(), HloSharding::AssignDevice(1)}));
+ }
+ {
+ // Nested tuple.
+ auto tuple_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}),
+ ShapeUtil::MakeShape(F32, {3, 7})})});
+ std::vector<HloSharding> leaf_shardings = {
+ HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
+ Array4D<int64>({{{{0}, {1}}}})),
+ HloSharding::Replicate(), HloSharding::AssignDevice(1)};
+ ShapeTree<HloSharding> sharding_tree(tuple_shape, HloSharding::Replicate());
+ // Assign leaf_shardings to sharding_tree leaves.
+ auto it = leaf_shardings.begin();
+ for (auto& index_to_sharding : sharding_tree.leaves()) {
+ index_to_sharding.second = *it++;
+ }
+ check(HloSharding::Tuple(sharding_tree));
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 15b2d8f499..8b3fa6c157 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -28,9 +29,11 @@ using Analysis = IndexedArrayAnalysis;
using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
+using tensorflow::gtl::ArraySlice;
+using tensorflow::str_util::Join;
} // namespace
-string IndexedArrayAnalysis::ToString(Array* root) {
+string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
switch (root->kind()) {
case Array::kUnknown: {
auto* unknown_tensor = root->as<UnknownArray>();
@@ -39,6 +42,12 @@ string IndexedArrayAnalysis::ToString(Array* root) {
}
case Array::kConstant: {
+ if (print_constants) {
+ string contents = root->as<ConstantArray>()->literal()->ToString();
+ return tensorflow::strings::StrCat(
+ "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents,
+ ")");
+ }
return tensorflow::strings::StrCat(
"(constant ", ShapeUtil::HumanString(root->shape()), ")");
}
@@ -50,26 +59,26 @@ string IndexedArrayAnalysis::ToString(Array* root) {
? "scalar-indexed-const"
: "scalar-indexed";
return tensorflow::strings::StrCat(
- "(", name, " ", ToString(indexed_array->source()), " ",
- ToString(indexed_array->indices()), " ", indexed_array->source_dim(),
- "->[", tensorflow::str_util::Join(indexed_array->output_dims(), ","),
- "])");
+ "(", name, " ", ToString(indexed_array->source(), print_constants),
+ " ", ToString(indexed_array->indices(), print_constants), " ",
+ indexed_array->source_dim(), "->[",
+ Join(indexed_array->output_dims(), ","), "])");
}
}
}
-Analysis::Array* IndexedArrayAnalysis::GetArrayFor(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor(
const HloInstruction* instr) {
auto it = cache_.find(instr);
if (it != cache_.end()) {
return it->second;
}
- TraverseAndPopulateCache(instr);
+ TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr));
return FindOrDie(cache_, instr);
}
-void IndexedArrayAnalysis::TraverseAndPopulateCache(
+Status IndexedArrayAnalysis::TraverseAndPopulateCache(
const HloInstruction* root) {
// Depth first search over the DAG, invoking ComputeArrayFor in post order.
// The HLO instructions already in the cache are considered leaves.
@@ -105,28 +114,46 @@ void IndexedArrayAnalysis::TraverseAndPopulateCache(
case kVisited:
stack.pop_back();
- InsertOrDie(&cache_, instr, ComputeArrayFor(instr));
+ TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr));
+ InsertOrDie(&cache_, instr, array);
break;
}
} while (!stack.empty());
+
+ return Status::OK();
}
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
const HloInstruction* instr) {
Array* computed_array;
- switch (instr->opcode()) {
- default:
- computed_array = nullptr;
- break;
- case HloOpcode::kConstant:
- computed_array = ComputeArrayForConstant(instr->literal());
- break;
- case HloOpcode::kGather:
- computed_array = ComputeArrayForGather(
- instr->shape(), instr->gather_dimension_numbers(),
- instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)),
- FindOrDie(cache_, instr->operand(1)));
- break;
+ if (instr->IsElementwise() && instr->operand_count() == 1) {
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForElementwiseUnaryOp(
+ instr->opcode(), FindOrDie(cache_, instr->operand(0))));
+ } else if (instr->IsElementwise() && instr->operand_count() == 2) {
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForElementwiseBinaryOp(
+ instr->opcode(), FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1))));
+ } else if (instr->opcode() == HloOpcode::kConstant) {
+ TF_ASSIGN_OR_RETURN(computed_array,
+ ComputeArrayForConstant(instr->literal()));
+ } else if (instr->opcode() == HloOpcode::kGather) {
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
+ instr->gather_window_bounds(),
+ FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1))));
+ } else if (instr->opcode() == HloOpcode::kReshape) {
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForReshape(instr->shape(),
+ FindOrDie(cache_, instr->operand(0))));
+ } else {
+ computed_array = nullptr;
}
if (!computed_array) {
@@ -136,12 +163,12 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
return computed_array;
}
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForConstant(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
const Literal& literal) {
return Construct<ConstantArray>(&literal);
}
-ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
+StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape) {
// We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
@@ -161,14 +188,14 @@ ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
IndexComponent::Ungathered);
// Simulate the first gather.
- simulated_index.erase(simulated_index.begin() + source->source_dim());
+ EraseAt(&simulated_index, source->source_dim());
for (int64 gather_dim : source->output_dims()) {
simulated_index.insert(simulated_index.begin() + gather_dim,
IndexComponent::GatheredFirst);
}
// Simulate the second gather.
- simulated_index.erase(simulated_index.begin() + source_dim);
+ EraseAt(&simulated_index, source_dim);
for (int64 output_dim : output_dims) {
simulated_index.insert(simulated_index.begin() + output_dim,
IndexComponent::GatheredSecond);
@@ -207,7 +234,7 @@ ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
std::move(shape));
}
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
const Shape& shape, const GatherDimensionNumbers& dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
Array* indices) {
@@ -244,6 +271,443 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather(
shape);
}
+namespace {
+// Returns an index into `values` such that the product of the range
+// [values.begin()+index, values.end()) is equal to `product`. If there is no
+// such index, return -1. All integers in `values` must be positive.
+int64 FindSuffixWithProduct(ArraySlice<int64> values, int64 product) {
+ DCHECK(c_all_of(values, [](int64 value) { return value > 0; }));
+
+ int64 current_product = 1;
+ int64 i;
+ for (i = values.size() - 1; i >= 0 && product > current_product; --i) {
+ current_product *= values[i];
+ }
+
+ if (product == current_product) {
+ return i + 1;
+ }
+
+ return -1;
+}
+
+struct ReshapePassthroughDimPair {
+ int64 result_dim;
+ int64 operand_dim;
+};
+
+// Returns a set of dimension pairs such for all (result_dim, operand_dim) in
+// the set:
+//
+// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim]
+//
+// The returned vector of pairs is sorted in both the result_dim and the
+// operand_dim components.
+std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
+ ArraySlice<int64> operand_shape, ArraySlice<int64> result_shape) {
+ // A reshape can be seen as an index mapping from output index to input index:
+ //
+ // (i_0, ..., i_n) = f(o_0, ..., o_m)
+ //
+ // This function returns the pairs (j, k) for which the following invariant
+ // holds for all indices in the shape:
+ //
+ // o_j == i_k
+ //
+ // And this occurs when:
+ //
+ // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m
+ //
+ // (where O_x are the sizes of the output shape and I_x are the sizes of the
+ // input shape) and the size of the dimension j of the result is the same as
+ // the size of dimension k in the operand.
+ //
+ // These conditions are sufficient because the Reshape HLO is spec'ed such
+ // that the rightmost dimensions are always minor in the flattening and refine
+ // operation.
+
+ std::vector<ReshapePassthroughDimPair> result;
+ int64 result_subarray_size = 1;
+ for (int64 result_dim = result_shape.size() - 1; result_dim >= 0;
+ --result_dim) {
+ int64 candidate_operand_dim =
+ FindSuffixWithProduct(operand_shape, result_subarray_size);
+
+ // result_subarray_size does not include the elements in the current
+ // `result_dim` dimension (we multiply in result_shape[result_dim] at the
+ // end of loop body) so candidate_operand_dim can never be zero.
+ CHECK_NE(candidate_operand_dim, 0);
+
+ if (candidate_operand_dim != -1 &&
+ result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
+ result.push_back({/*result_dim=*/result_dim,
+ /*operand_dim=*/candidate_operand_dim - 1});
+ }
+ result_subarray_size *= result_shape[result_dim];
+ }
+
+ c_reverse(result);
+
+ if (VLOG_IS_ON(3)) {
+ std::vector<string> result_strings;
+ c_transform(result, std::back_inserter(result_strings),
+ [](ReshapePassthroughDimPair value) {
+ return tensorflow::strings::StrCat(value.result_dim, "->",
+ value.operand_dim);
+ });
+ VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to ["
+ << Join(result_shape, ",") << "] passthrough indices are ["
+ << Join(result_strings, ",") << "]";
+ }
+
+ DCHECK(c_is_sorted(
+ result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
+ return lhs.result_dim < rhs.result_dim;
+ }));
+
+ DCHECK(c_is_sorted(
+ result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
+ return lhs.operand_dim < rhs.operand_dim;
+ }));
+
+ return result;
+}
+
+// Return true if `dim` is stated as an passthrough operand dim in
+// `passthrough_dims`.
+bool IsReshapePassthroughOperandDim(
+ ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 dim) {
+ return c_any_of(passthrough_dims,
+ [&](ReshapePassthroughDimPair passthrough_dim_pair) {
+ return passthrough_dim_pair.operand_dim == dim;
+ });
+}
+
+// Maps `operand_dim` which must be an passthrough operand dimension to its
+// corresponding passthrough result dimension based on `passthrough_dims`.
+int64 MapPassthroughOperandDimToResultDim(
+ ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 operand_dim) {
+ auto it = c_find_if(passthrough_dims,
+ [&](ReshapePassthroughDimPair passthrough_dim_pair) {
+ return passthrough_dim_pair.operand_dim == operand_dim;
+ });
+ CHECK(it != passthrough_dims.end());
+ return it->result_dim;
+}
+
+int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
+ ArraySlice<int64> result_shape,
+ int64 source_passthrough_dim) {
+ int64 indexed_source_subarray_size =
+ std::accumulate(operand_shape.begin() + source_passthrough_dim + 1,
+ operand_shape.end(), 1, std::multiplies<int64>());
+
+ return FindSuffixWithProduct(result_shape, indexed_source_subarray_size);
+}
+
+}; // namespace
+
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
+ const Shape& shape, Array* operand) {
+ auto* scalar_indexed = dynamic_cast<ScalarIndexedConstantArray*>(operand);
+ if (!scalar_indexed) {
+ return nullptr;
+ }
+
+ // Try to fold Reshape(ScalarIndexed(Const, Indices))
+ // => ScalarIndexed(Const', Indices)
+ //
+ // We can view the reshape and the scalar-indexed operations as functions that
+ // map an output index (i.e. an index into the result) to an input index
+ // (i.e. an index into the operand). The key idea used here is that the
+ // output-to-input mapping for some reshape operations may "pass through" some
+ // output dimensions into the input space unchanged -- i.e. there may exist
+ // output dimension "O" and input dimension "I" such that OutputIndex[O] is
+ // always == InputIndexForReshape(OutputIndex)[I]. If these pass-through
+ // dimensions in the input space of the reshape happen to be include all the
+ // output dimensions for the scalar-indexed node then, roughly, the following
+ // holds:
+ //
+ // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx))
+ // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs))
+ //
+ // Where Ps are the set of the pass-through components of Idx that are
+ // also the output dims of the scalar-indexed node, and Qs are the rest.
+ // For brevity, we're playing fast and loose with the notation here -- we
+ // don't literally require Idx to be a concatenation of Ps and Qs, as
+ // suggested by the "++".
+ //
+ // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs))
+ //
+ // Again, we're playing fast and loose with the notation around "++".
+ // Generally this ++ will be a different function that the ++ in the
+ // previous step.
+ //
+ // If the scalar-indexed node has a constant as the source then the
+ // SourceIndexOfReshape function can be "folded into" the constant itself by
+ // reshaping it, leaving us with:
+ //
+ // == SourceIndexOfScalarIndexed(Ps ++ Qs)
+ // == SourceIndexOfScalarIndexed(Idx)
+ //
+ // which is just a scalar-indexed node (with parameters different from the
+ // scalar-indexed node we started with) with a reshaped constant as the
+ // source.
+ //
+ // We can't fold SourceIndexOfReshape into the constant without introducing
+ // another precondition: since the new scalar-indexed node will have a
+ // reshaped (constant) array as its source it will, in general, have a
+ // different source dimension than the original scalar-indexed node. This
+ // source dimension will have to be a passthrough dimension of the
+ // SourceIndexOfReshape indexing function that is folded into the source. And
+ // such a dimension need not exist so this is a non-trivial precondition.
+
+ std::vector<ReshapePassthroughDimPair> reshape_passthrough_dims =
+ ComputeReshapePassthroughDimPairs(
+ /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()),
+ /*result_shape=*/AsInt64Slice(shape.dimensions()));
+
+ auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) {
+ return IsReshapePassthroughOperandDim(reshape_passthrough_dims,
+ operand_dim);
+ };
+
+ if (!c_all_of(scalar_indexed->output_dims(),
+ is_reshape_passthrough_operand_dim)) {
+ return nullptr;
+ }
+
+ // To compute the shape of the source for the new scalar-indexed node we're
+ // going to create, we first "undo" the scalar-indexed operation.
+ std::vector<int64> new_scalar_indexed_source_shape(shape.dimensions().begin(),
+ shape.dimensions().end());
+ for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) {
+ int64 output_dim = scalar_indexed->output_dims()[i];
+ int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim(
+ reshape_passthrough_dims, output_dim);
+ EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape);
+ }
+
+ // After this, we need to add in the dimension that will be the source
+ // dimension for the new scalar-indexed node. A scalar-indexed node "removes"
+ // the source dimensions and "adds" the output dimensions, so to get back to
+ // the shape for the *source* of the scalar-indexed node we need to remove the
+ // output dims (which we did above) and then add back the source dim (which we
+ // are about to do below):
+
+ const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape();
+
+ int64 source_dim_for_new_scalar_indexed_node =
+ FindSourcePositionForPassthroughResultDim(
+ /*operand_shape=*/AsInt64Slice(
+ scalar_indexed_source_shape.dimensions()),
+ /*result_shape=*/new_scalar_indexed_source_shape,
+ scalar_indexed->source_dim());
+
+ // We may not be able to find a source dim for the new scalar-indexed node.
+ // For instance consider:
+ //
+ // operand = s32[3,5,2] constant({...})
+ // indices = s32[7] parameter(0)
+ // gather = s32[3,2,7] gather(operand, indices),
+ // output_window_dims={0,1},
+ // elided_window_dims={1},
+ // gather_dims_to_operand_dims={1},
+ // index_vector_dim=1,
+ // window_bounds={3,1,2}
+ // reshape = s32[6,7] reshape(gather)
+ //
+ // In this case the gather maps to:
+ // (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2])
+ //
+ // and the reshape passes through dimension 2 from its input into dimension 1
+ // in its output. However, we can't rewrite the reshape as a scalar-indexed
+ // node because then we'd have to reshape the [3,5,2] `operand` array to
+ // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently
+ // (a.k.a. isn't pass-through) than the [3,5,2] array.
+
+ if (source_dim_for_new_scalar_indexed_node == -1) {
+ return nullptr;
+ }
+
+ InsertAt(
+ &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
+ scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));
+
+ CHECK(IsReshapePassthroughOperandDim(
+ ComputeReshapePassthroughDimPairs(
+ /*operand_shape=*/AsInt64Slice(
+ scalar_indexed_source_shape.dimensions()),
+ /*result_shape=*/new_scalar_indexed_source_shape),
+ scalar_indexed->source_dim()));
+
+ auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) {
+ return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims,
+ result_dim);
+ };
+
+ std::vector<int64> output_dims_for_new_scalar_indexed_node;
+ c_transform(scalar_indexed->output_dims(),
+ std::back_inserter(output_dims_for_new_scalar_indexed_node),
+ map_passthrough_operand_dim_to_result_dim);
+
+ TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
+ TakeOwnership(scalar_indexed->literal().Reshape(
+ new_scalar_indexed_source_shape)));
+ TF_ASSIGN_OR_RETURN(
+ Array * new_scalar_indexed_source,
+ ComputeArrayForConstant(*new_scalar_indexed_source_literal));
+
+ return ConstructScalarIndexedArray(
+ new_scalar_indexed_source, scalar_indexed->indices(),
+ source_dim_for_new_scalar_indexed_node,
+ output_dims_for_new_scalar_indexed_node, shape);
+}
+
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
+ Array* lhs,
+ Array* rhs) {
+ // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
+ // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
+ //
+ // We can do this if every output dimension from the scalar-indexed node is a
+ // broadcasted dimension for the broadcast node. Informally, the precondition
+ // means Broadcast(Const0)[IDX] is solely a function of the components of IDX
+ // that are not output-dims for the scalar-indexed node. In other words, for
+ // every assignment to the non-output dims in IDX we have a "constant" LHS to
+ // the BinaryOp. This transform propagates this "constant" to the source for
+ // the scalar-indexed node.
+
+ ScalarIndexedConstantArray* lhs_scalar_indexed_const =
+ dynamic_cast<ScalarIndexedConstantArray*>(lhs);
+ ScalarIndexedConstantArray* rhs_scalar_indexed_const =
+ dynamic_cast<ScalarIndexedConstantArray*>(rhs);
+
+ bool lhs_is_indexed;
+
+ // One of the operands must be scalar-indexed and the other must be a
+ // broadcast of a constant.
+ if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) {
+ lhs_is_indexed = true;
+ } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) {
+ lhs_is_indexed = false;
+ } else {
+ return nullptr;
+ }
+
+ ScalarIndexedConstantArray* scalar_indexed_const =
+ lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const;
+ UnknownArray* candidate_broadcast_array =
+ dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs);
+ if (!candidate_broadcast_array ||
+ candidate_broadcast_array->instruction().opcode() !=
+ HloOpcode::kBroadcast) {
+ return nullptr;
+ }
+
+ const HloInstruction* broadcast_instr =
+ &candidate_broadcast_array->instruction();
+ const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0);
+ if (broadcast_const_operand->opcode() != HloOpcode::kConstant) {
+ return nullptr;
+ }
+
+ ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions();
+ auto is_broadcasted_dim = [&](int64 output_dim) {
+ return c_find(broadcast_dims, output_dim) == broadcast_dims.end();
+ };
+
+ // All of the output dims must be "broadcasted" dims for the other operand.
+ if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) {
+ return nullptr;
+ }
+
+ // To figure out the broadcast dimensions for the (constant) source for the
+ // scalar-indexed node, we "simulate" the index transformation done by the
+ // existing broadcsat:
+ enum class IndexComponent { Broadcasted, NotBroadcasted };
+ std::vector<IndexComponent> simulated_index(
+ broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted);
+ for (int64 broadcast_dim : broadcast_dims) {
+ simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted;
+ }
+
+ // The scalar-indexed node "removes" the source dim and "inserts" the output
+ // dims. We do the opposite here to undo the scalar-indexed operation.
+ ArraySlice<int64> output_dims = scalar_indexed_const->output_dims();
+ for (int64 i = output_dims.size() - 1; i >= 0; --i) {
+ CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
+ EraseAt(&simulated_index, output_dims[i]);
+ }
+
+ InsertAt(&simulated_index, scalar_indexed_const->source_dim(),
+ IndexComponent::Broadcasted);
+
+ // new_inner_broadcast_dims holds the broadcast dimensions for the inner
+ // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to
+ // new_inner_broadcast_dims.
+ std::vector<int64> new_inner_broadcast_dims;
+ for (int64 i = 0; i < simulated_index.size(); i++) {
+ if (simulated_index[i] == IndexComponent::NotBroadcasted) {
+ new_inner_broadcast_dims.push_back(i);
+ }
+ }
+
+ // inner_broadcast_result is the Broadcast'(Const0) bit in
+ // BinaryOp(Broadcast'(Const0), Const1)
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Literal> inner_broadcast_result,
+ broadcast_const_operand->literal().Broadcast(
+ scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
+
+ // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1)
+ const Literal* literal_for_new_source;
+ if (lhs_is_indexed) {
+ TF_ASSIGN_OR_RETURN(
+ literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
+ opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
+ opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
+ }
+
+ ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+ return Construct<ScalarIndexedConstantArray>(
+ new_source, scalar_indexed_const->indices(),
+ scalar_indexed_const->source_dim(),
+ std::vector<int64>(scalar_indexed_const->output_dims().begin(),
+ scalar_indexed_const->output_dims().end()),
+ scalar_indexed_const->shape());
+}
+
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
+ Array* operand) {
+ auto* scalar_indexed_const =
+ dynamic_cast<ScalarIndexedConstantArray*>(operand);
+ if (scalar_indexed_const == nullptr) {
+ return nullptr;
+ }
+
+ // Fold UnaryOp(ScalarIndexed(Const, Indices))
+ // => ScalarIndexed(UnaryOp(Const), Indices)
+
+ TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
+ opcode, scalar_indexed_const->literal())));
+ ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+ return Construct<ScalarIndexedConstantArray>(
+ new_source, scalar_indexed_const->indices(),
+ scalar_indexed_const->source_dim(),
+ std::vector<int64>(scalar_indexed_const->output_dims().begin(),
+ scalar_indexed_const->output_dims().end()),
+ scalar_indexed_const->shape());
+}
+
tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
return "indexed-array-analysis-printer-pass";
}
@@ -256,7 +720,7 @@ StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(HloModule* module) {
IndexedArrayAnalysis analysis;
for (auto* computation : module->MakeNonfusionComputations()) {
for (auto* instr : computation->instructions()) {
- auto* t = analysis.GetArrayFor(instr);
+ TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr));
if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t);
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index b132a8f251..ce92fd2919 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -143,8 +143,8 @@ class IndexedArrayAnalysis {
//
// For example, if source is of shape [11,13,17,19], indices is of shape
// [23,29], output_dims is [0,2] and source_dim is 2 then the output is of
- // shape [23,11,29,19] and the output index [A,B,C,D,E] is mapped to the input
- // index [B,D,indices[A,C],E].
+ // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the
+ // input index [B,D,indices[A,C],E].
class ScalarIndexedArray : public Array {
public:
Kind kind() const override { return kScalarIndexed; }
@@ -152,7 +152,15 @@ class IndexedArrayAnalysis {
Array* source() const { return source_; }
Array* indices() const { return indices_; }
+
+ // `source_dim` is the dimension in the source array that is being indexed
+ // over using indices from the `indices` array. See the class documentation
+ // and the overview for more details.
int64 source_dim() const { return source_dim_; }
+
+ // `output_dims` are the dimensions in the output array that are being used
+ // to compute an index into the `indices` array. See the class
+ // documentation and the overview for more details.
tensorflow::gtl::ArraySlice<int64> output_dims() const {
return output_dims_;
}
@@ -212,26 +220,26 @@ class IndexedArrayAnalysis {
// NB! By inspecting the implementation, you may be able to infer a stronger
// caching guarantee than what is mentioned above. Nevertheless, what is
// stated above is the contract.
- Array* GetArrayFor(const HloInstruction* instr);
+ StatusOr<Array*> GetArrayFor(const HloInstruction* instr);
// Pretty-prints the expression rooted at `root`.
- string ToString(Array* root);
+ string ToString(Array* root, bool print_constants = false);
private:
// Helper function that ensures that every HLO instruction that is
// transitively used by `root` has an entry in `cache_`.
- void TraverseAndPopulateCache(const HloInstruction* root);
+ Status TraverseAndPopulateCache(const HloInstruction* root);
// Creates an Array instance for `instr` under the assumption that all
// operations of `instr` are present in `cache_`.
- Array* ComputeArrayFor(const HloInstruction* instr);
+ StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr);
- Array* ComputeArrayForConstant(const Literal& literal);
+ StatusOr<Array*> ComputeArrayForConstant(const Literal& literal);
- Array* ComputeArrayForGather(const Shape& shape,
- const GatherDimensionNumbers& dim_numbers,
- tensorflow::gtl::ArraySlice<int64> window_bounds,
- Array* source, Array* indices);
+ StatusOr<Array*> ComputeArrayForGather(
+ const Shape& shape, const GatherDimensionNumbers& dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+ Array* indices);
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
@@ -254,10 +262,17 @@ class IndexedArrayAnalysis {
//
// I2 = [I0[i] for i in I1]
// G1 = [Arr[i] for i in I2]
- ScalarIndexedArray* FoldGatherOfGather(
+ StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
ScalarIndexedArray* source, Array* indices, int64 source_dim,
tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
+ StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);
+
+ StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
+ Array* lhs, Array* rhs);
+ StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
+ Array* operand);
+
template <typename T, typename... Args>
T* Construct(Args&&... args) {
T* new_tensor = new T(std::forward<Args>(args)...);
@@ -279,6 +294,19 @@ class IndexedArrayAnalysis {
}
}
+ Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
+ owned_literals_.push_back(std::move(literal));
+ return owned_literals_.back().get();
+ }
+
+ StatusOr<Literal*> TakeOwnership(
+ StatusOr<std::unique_ptr<Literal>> literal_or_error) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ std::move(literal_or_error));
+ owned_literals_.push_back(std::move(literal));
+ return owned_literals_.back().get();
+ }
+
std::vector<std::unique_ptr<Array>> owned_tensors_;
std::vector<std::unique_ptr<Literal>> owned_literals_;
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index b2731b7c51..373556ebeb 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -23,14 +23,31 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
protected:
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {
+ AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
+ /*print_constants=*/false);
+ }
+
+ void AssertArrayWithConstantsForRootExpressionIs(
+ const string& hlo_text, const string& root_expression) {
+ AssertArrayForRootExpressionIsImpl(hlo_text, root_expression,
+ /*print_constants=*/true);
+ }
+
+ private:
+ void AssertArrayForRootExpressionIsImpl(const string& hlo_text,
+ const string& root_expression,
+ bool print_constants) {
IndexedArrayAnalysis indexed_tensor_analysis;
ParseAndVerifyModule(hlo_text);
- string result =
- indexed_tensor_analysis.ToString(indexed_tensor_analysis.GetArrayFor(
+ TF_ASSERT_OK_AND_ASSIGN(
+ IndexedArrayAnalysis::Array* const array_result,
+ indexed_tensor_analysis.GetArrayFor(
module().entry_computation()->root_instruction()));
- LOG(INFO) << result;
- ASSERT_EQ(result, root_expression);
+ string string_result =
+ indexed_tensor_analysis.ToString(array_result, print_constants);
+ LOG(INFO) << string_result;
+ ASSERT_EQ(string_result, root_expression);
}
};
@@ -187,5 +204,301 @@ ENTRY main {
"(scalar-indexed %operand (scalar-indexed %indices_a %indices_b "
"1->[0,2]) 1->[0,1,3])");
}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT reshape = s32[5,2,2] reshape(gather)
+}
+)";
+
+ AssertArrayForRootExpressionIs(
+ hlo_text, "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])");
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
+ indices = s32[5,7] parameter(0)
+ gather = s32[5,4,7] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=2,
+ window_bounds={1,4}
+ ROOT reshape = s32[5,2,2,7] reshape(gather)
+}
+)";
+
+ AssertArrayForRootExpressionIs(
+ hlo_text,
+ "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0,3])");
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[3,2,6] constant(s32[3,2,6]{
+ {{1,2,3,4,5,6},{1,2,3,4,5,6}},
+ {{1,2,3,4,5,6},{1,2,3,4,5,6}},
+ {{1,2,3,4,5,6},{1,2,3,4,5,6}}})
+ indices = s32[5,7] parameter(0)
+ gather = s32[5,2,6,7] gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=2,
+ window_bounds={1,2,6}
+ ROOT reshape = s32[5,3,4,7] reshape(gather)
+}
+)";
+
+ AssertArrayForRootExpressionIs(
+ hlo_text,
+ "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])");
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}})
+ indices = s32[5,6] parameter(0)
+ gather = s32[5,4,6] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=2,
+ window_bounds={1,4}
+ ROOT reshape = s32[5,2,2,2,3] reshape(gather)
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%reshape");
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[3,5,2] constant(s32[3,5,2]{
+ {{1,2},{3,4},{5,6},{7,8},{9,10}},
+ {{1,2},{3,4},{5,6},{7,8},{9,10}},
+ {{1,2},{3,4},{5,6},{7,8},{9,10}}})
+ indices = s32[7] parameter(0)
+ gather = s32[3,2,7] gather(operand, indices),
+ output_window_dims={0,1},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3,1,2}
+ ROOT reshape = s32[6,7] reshape(gather)
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%reshape");
+}
+
+TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) {
+ string hlo_text = R"(
+HloModule UnaryOpOfGather
+
+ENTRY main {
+ operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ indices = s32[5] parameter(0)
+ gather = f32[5,4] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT tanh = f32[5,4] tanh(gather)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant f32[3,4] f32[3,4] {
+ { 0.761594176, 0.964027584, 0.995054781, 0.999329329 },
+ { 0.761594176, 0.995054781, 0.964027584, 0.999329329 },
+ { 0.999329329, 0.995054781, 0.964027584, 0.761594176 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) {
+ string hlo_text = R"(
+HloModule AddBroadcastedScalarWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant = s32[] constant(5)
+ constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+ { 6, 7, 8, 9 },
+ { 6, 8, 7, 9 },
+ { 9, 8, 7, 6 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest,
+ SubtractBroadcastedScalarWithGather_GatherIsLhs) {
+ string hlo_text = R"(
+HloModule SubtractBroadcastedScalarWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant = s32[] constant(5)
+ constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT sub = s32[5,4] subtract(gather, constant_broadcasted)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+ { -4, -3, -2, -1 },
+ { -4, -2, -3, -1 },
+ { -1, -2, -3, -4 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest,
+ SubtractBroadcastedScalarWithGather_GatherIsRhs) {
+ string hlo_text = R"(
+HloModule SubtractBroadcastedScalarWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant = s32[] constant(5)
+ constant_broadcasted = s32[5,4] broadcast(constant), dimensions={}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT sub = s32[5,4] subtract(constant_broadcasted, gather)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+ { 4, 3, 2, 1 },
+ { 4, 2, 3, 1 },
+ { 1, 2, 3, 4 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) {
+ string hlo_text = R"(
+HloModule AddBroadcastedVectorWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant_vect = s32[4] constant({10,11,12,13})
+ constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"(
+(scalar-indexed-const (constant s32[3,4] s32[3,4] {
+ { 11, 13, 15, 17 },
+ { 11, 14, 14, 17 },
+ { 14, 14, 14, 14 }
+}) %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) {
+ string hlo_text = R"(
+HloModule AddBroadcastedVectorWithGather
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}})
+ constant_vect = s32[5] constant({10,11,12,13,14})
+ constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0}
+ indices = s32[5] parameter(0)
+ gather = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT add = s32[5,4] add(gather, constant_broadcasted)
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%add");
+}
+
+TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) {
+ string hlo_text = R"(
+HloModule RegularUnaryOp
+
+ENTRY main {
+ input = f32[100] parameter(0)
+ ROOT tanh = f32[100] tanh(input)
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%tanh");
+}
+
+TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) {
+ string hlo_text = R"(
+HloModule RegularUnaryOp
+
+ENTRY main {
+ input0 = f32[100] parameter(0)
+ input1 = f32[100] parameter(1)
+ ROOT add = f32[100] add(input0, input1)
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%add");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index cb6c98c481..1912b8f2c7 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -178,8 +178,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
bool InstructionFusion::CanFuseOnAllPaths(
HloInstruction* producer, HloInstruction* consumer,
- const HloReachabilityMap& reachability_map,
- const DoNotFuseSet& do_not_fuse) {
+ const HloInstructionSet& do_not_duplicate) {
if (consumer == producer) {
return true;
}
@@ -190,10 +189,11 @@ bool InstructionFusion::CanFuseOnAllPaths(
auto* consumer_operand = consumer->mutable_operand(i);
// If the operand is not on a path to the producer, it doesn't matter
// whether it's fusable.
- if (!reachability_map.IsReachable(producer, consumer_operand)) {
+ if (!reachability_->IsReachable(producer, consumer_operand)) {
continue;
}
- if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
+ if (do_not_duplicate.count(consumer_operand) > 0 ||
+ !ShouldFuse(consumer, i)) {
return false;
}
// The producer is reachable from consumer_operand which means we need
@@ -201,18 +201,16 @@ bool InstructionFusion::CanFuseOnAllPaths(
// producer to be fusable into consumer on all paths.
// Perform the recursive step: make sure producer can be fused into
// consumer_operand on all paths.
- if (!CanFuseOnAllPaths(producer, consumer_operand, reachability_map,
- do_not_fuse)) {
+ if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) {
return false;
}
}
return true;
}
-InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable(
+InstructionFusion::HloInstructionSet
+InstructionFusion::ComputeGloballyUnfusable(
tensorflow::gtl::ArraySlice<HloInstruction*> post_order) {
- auto reachability = computation_->ComputeReachability();
-
// Forbid fusion of producers that:
// a) Need to be duplicated, unless they can be fused into all consumers
// via all paths.
@@ -222,10 +220,10 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable(
// Note that if we allow fusion by these global rules, we may still forbid
// fusing operations that require duplication later depending on
// is_expensive_().
- DoNotFuseSet do_not_fuse;
+ HloInstructionSet do_not_duplicate;
for (HloInstruction* consumer : post_order) {
for (HloInstruction* producer : consumer->operands()) {
- if (do_not_fuse.count(producer) > 0) {
+ if (do_not_duplicate.count(producer) > 0) {
continue;
}
@@ -254,14 +252,14 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable(
// A will be not allowed to be fused into B, as it cannot be fused via
// all paths.
if (producer->IsFusable() &&
- CanFuseOnAllPaths(producer, consumer, *reachability, do_not_fuse)) {
+ CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) {
continue;
}
- do_not_fuse.insert(producer);
+ do_not_duplicate.insert(producer);
}
}
- return do_not_fuse;
+ return do_not_duplicate;
}
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
@@ -273,6 +271,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
for (auto* computation : module->MakeNonfusionComputations()) {
CHECK(!computation->IsFusionComputation());
computation_ = computation;
+ reachability_ = computation_->ComputeReachability();
// We want to be able to remove arbitrary instructions from the post order
// and also compare positions of instructions in the post order. To make
@@ -290,7 +289,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
InsertOrDie(&post_order_index, post_order[i], i);
}
- DoNotFuseSet do_not_fuse = ComputeGloballyUnfusable(post_order);
+ HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order);
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
@@ -358,9 +357,20 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// ensures that B will be considered before A.
//
// We store the original indices of the operands to pass to ShouldFuse.
- std::vector<int64> sorted_operand_numbers(instruction->operands().size());
- std::iota(std::begin(sorted_operand_numbers),
- std::end(sorted_operand_numbers), 0);
+ std::vector<int64> sorted_operand_numbers;
+ sorted_operand_numbers.reserve(instruction->operands().size());
+ for (int i = 0; i < instruction->operands().size(); ++i) {
+ // This will happen if we have two possible instructions to fuse the
+ // same operand into; once the operand is fused into one instruction,
+ // the other instruction will get a new get-tuple-element as its
+ // operand, which is not in the post-order index.
+ // TODO(tjoerg): Look into fusing past these multi-output fuse points.
+ if (post_order_index.find(instruction->mutable_operand(i)) ==
+ post_order_index.end()) {
+ continue;
+ }
+ sorted_operand_numbers.push_back(i);
+ }
std::sort(
sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
[&](int64 i, int64 j) {
@@ -377,13 +387,20 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
if (!operand->IsFusable()) {
continue;
}
- if (!ShouldFuse(instruction, i)) {
- continue;
- }
- if (do_not_fuse.count(operand) > 0) {
+
+ HloInstruction* fusion_instruction;
+ // Try "regular" fusion if the operand may be duplicated. Otherwise,
+ // perform multi-output fusion, unless this creates a cycle.
+ // TODO(tjoerg): Consider making multi-output fusion the default.
+ if (ShouldFuse(instruction, i) &&
+ do_not_duplicate.count(operand) == 0) {
+ fusion_instruction = Fuse(operand, instruction);
+ } else if (ShouldFuseIntoMultiOutput(instruction, i) &&
+ !MultiOutputFusionCreatesCycle(operand, instruction)) {
+ fusion_instruction = FuseIntoMultiOutput(operand, instruction);
+ } else {
continue;
}
- HloInstruction* fusion_instruction = Fuse(operand, instruction);
// Fusing an instruction into a fusion instruction can change the
// operand set of the fusion instruction. For simplicity just push the
@@ -449,6 +466,19 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput(
return fusion_instruction;
}
+bool InstructionFusion::MultiOutputFusionCreatesCycle(
+ HloInstruction* producer, HloInstruction* consumer) {
+ return c_any_of(
+ consumer->operands(), [&](const HloInstruction* consumer_operand) {
+ // The fusion algorithm traverses the HLO graph in reverse post order.
+ // Thus `cosumers` is visited before its operands (including
+ // `producer`). Therefore, consumer operands cannot have been fused yet.
+ // It is thus safe to use the pre-computed reachability map.
+ return consumer_operand != producer &&
+ reachability_->IsReachable(producer, consumer_operand);
+ });
+}
+
bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
int64 operand_index) {
HloInstruction* producer = consumer->mutable_operand(operand_index);
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index c3c2ed0aaa..f73ca9adf7 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -61,6 +61,14 @@ class InstructionFusion : public HloPassInterface {
// Subtypes can override this with target-specific heuristics.
virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index);
+ // Returns whether multi-output fusion can be applied to fuse `producer` into
+ // `consumer`. In contrast to "regular" fusion, the `producer` is not
+ // duplicated by multi-output fusion.
+ virtual bool ShouldFuseIntoMultiOutput(HloInstruction* consumer,
+ int64 operand_index) {
+ return false;
+ }
+
// Chooses a fusion kind for `producer` and `consumer`.
// Default method chooses `kLoop`.
virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
@@ -97,10 +105,12 @@ class InstructionFusion : public HloPassInterface {
// Current HloComputation instance the loop fuser is traversing.
HloComputation* computation_;
HloModule* module_;
+ // Reachability information for the current computation.
+ std::unique_ptr<HloReachabilityMap> reachability_;
private:
// The set of producers whose consumers we cannot fuse into.
- using DoNotFuseSet = std::unordered_set<HloInstruction*>;
+ using HloInstructionSet = std::unordered_set<HloInstruction*>;
HloInstruction* AddFusionInstruction(HloInstruction* producer,
HloInstruction* consumer);
@@ -108,18 +118,21 @@ class InstructionFusion : public HloPassInterface {
// Whether or not we can fuse producer into consumer on all paths
// from the producer to the consumer where nodes are HLOs and edges are uses.
bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer,
- const HloReachabilityMap& reachability_map,
- const DoNotFuseSet& do_not_fuse);
+ const HloInstructionSet& do_not_fuse);
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
- DoNotFuseSet ComputeGloballyUnfusable(
+ HloInstructionSet ComputeGloballyUnfusable(
tensorflow::gtl::ArraySlice<HloInstruction*> post_order);
// Used to determine if an HLO is expensive. Expensive operations will not be
// duplicated.
std::function<bool(const HloInstruction& instruction)> is_expensive_;
+ // Whether multi-output fusion would introduce a cycle into the HLO graph.
+ bool MultiOutputFusionCreatesCycle(HloInstruction* producer,
+ HloInstruction* consumer);
+
// Returns whether we may duplicate an instruction if we want to fuse it.
bool may_duplicate_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 3a21eda357..5fc08aab91 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -24,15 +24,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace llvm_ir {
-void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
- llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
- llvm::Module* module) {
+void EmitTupleSelect(const IrArray& select, const IrArray& pred,
+ llvm::Value* on_true, llvm::Value* on_false,
+ llvm::IRBuilder<>* ir_builder, llvm::Module* module) {
CHECK(ShapeUtil::IsScalar(pred.GetShape()));
llvm::LoadInst* pred_value =
@@ -47,30 +46,27 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
VLOG(2) << " pred_cond: " << DumpToString(*pred_cond);
for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) {
- std::vector<llvm::Value*> element_index = {ir_builder->getInt64(0),
- ir_builder->getInt64(i)};
+ llvm::Value* const element_index[] = {ir_builder->getInt64(0),
+ ir_builder->getInt64(i)};
llvm::Value* on_true_element_address =
ir_builder->CreateInBoundsGEP(on_true, element_index);
llvm::Value* on_true_element = ir_builder->CreateLoad(
- on_true_element_address,
- tensorflow::strings::Printf("on_true_element_%d", i).c_str());
+ on_true_element_address, "on_true_element_" + llvm::Twine(i));
llvm::Value* on_false_element_address =
ir_builder->CreateInBoundsGEP(on_false, element_index);
llvm::Value* on_false_element = ir_builder->CreateLoad(
- on_false_element_address,
- tensorflow::strings::Printf("on_false_element_%d", i).c_str());
+ on_false_element_address, "on_false_element_" + llvm::Twine(i));
llvm::Value* output_element_address =
ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index);
ir_builder->CreateStore(
- ir_builder->CreateSelect(
- pred_cond, on_true_element, on_false_element,
- tensorflow::strings::Printf("select_output_element_%d", i).c_str()),
+ ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element,
+ "select_output_element_" + llvm::Twine(i)),
output_element_address);
}
}
-void EmitTuple(IrArray tuple,
+void EmitTuple(const IrArray& tuple,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
llvm::IRBuilder<>* ir_builder, llvm::Module* module) {
for (size_t i = 0; i < operands.size(); ++i) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
index dbf9a14006..352d34ebf8 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
@@ -59,13 +59,13 @@ namespace llvm_ir {
// of the address from the corresponding element in either
// tuple_on_true or tuple_on_false:
// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i]
-void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
- llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
- llvm::Module* module);
+void EmitTupleSelect(const IrArray& select, const IrArray& pred,
+ llvm::Value* on_true, llvm::Value* on_false,
+ llvm::IRBuilder<>* ir_builder, llvm::Module* module);
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand.
-void EmitTuple(IrArray tuple,
+void EmitTuple(const IrArray& tuple,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
llvm::IRBuilder<>* ir_builder, llvm::Module* module);
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index 321fdeb1ea..09ddcffb22 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy(
// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
-static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
+bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
+ const HloInstruction& instruction) {
switch (instruction.opcode()) {
default:
return false;
+ case HloOpcode::kConstant:
+ return !hoist_constants_;
+
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
- case HloOpcode::kConstant:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
@@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
}
}
-static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+StatusOr<bool>
+WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
HloInstruction* while_instr) {
auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
@@ -161,12 +165,16 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
}
}
- if (unhoisted_invariant_instructions.empty()) {
+ if (unhoisted_invariant_instructions.empty() && !hoist_constants_) {
// There are no obviously loop invariant elements in the state being
// threaded through the while loop so give up. In theory this precondition
// is too strong -- we could have code that e.g. permutes the elements in
// the while state but uses a select to pick the same value on every
// iteration.
+ //
+ // If we were asked to hoist constants, we need to scan the while body for
+ // constants even if we didn't find any loop invariant values in the while
+ // state tuple.
return false;
}
@@ -243,6 +251,9 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
}
StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
+ VLOG(2) << "HLO module before WhileLoopConstantSinking:";
+ XLA_VLOG_LINES(2, module->ToString());
+
bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->computations()) {
@@ -270,6 +281,14 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
TryHoistingInvariantInstructionsFromWhileBody(while_instr));
changed |= result;
}
+
+ if (changed) {
+ VLOG(2) << "HLO module after WhileLoopConstantSinking:";
+ XLA_VLOG_LINES(2, module->ToString());
+ } else {
+ VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking";
+ }
+
return changed;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 8c4b765b00..8e6cc87875 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -27,12 +27,28 @@ namespace xla {
class WhileLoopInvariantCodeMotion : public HloPassInterface {
public:
+ // If `hoist_constants` is true then constants are always hoisted out of while
+ // loop bodies. Otherwise they are only hoisted out if they enable other
+ // non-trivial computations to be hoisted out.
+ //
+ // Setting `hoist_constants` to false can be help if LICM is run in the mid
+ // level HLO pipeline because hoisting constants out of while loop bodies can
+ // break optimizations like constant folding.
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false)
+ : hoist_constants_(hoist_constants) {}
~WhileLoopInvariantCodeMotion() override = default;
tensorflow::StringPiece name() const override {
return "while-loop-invariant-code-motion";
}
StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ bool NotWorthHoistingIndividually(const HloInstruction& instruction);
+ StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+ HloInstruction* while_instr);
+
+ bool hoist_constants_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 799340fda9..e1ec12192f 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -438,5 +439,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
EXPECT_FALSE(simplified_loop);
}
+const char* const kConstantHoistingTestCase = R"(
+HloModule ModuleWithWhile
+
+body {
+ p_body = (f32[2]{0}) parameter(0)
+ p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
+ const = f32[2]{0} constant({3, 4})
+ add.0 = f32[2]{0} add(p_body.1, const)
+ ROOT root = (f32[2]{0}) tuple(add.0)
+}
+
+condition {
+ p_cond = (f32[2]{0}) parameter(0)
+ ROOT result = pred[] constant(true)
+}
+
+ENTRY entry {
+ const_0 = f32[2]{0} constant({1, 2})
+ while_init = (f32[2]{0}) tuple(const_0)
+ ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body
+}
+)";
+
+TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) {
+ ParseAndVerifyModule(kConstantHoistingTestCase);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool simplified_loop,
+ WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module()));
+ EXPECT_TRUE(simplified_loop);
+
+ HloComputation* while_body = module().GetComputationWithName("wide.body");
+ ASSERT_NE(while_body, nullptr);
+
+ // We expect the while body to be the equivalent of:
+ //
+ // wide.body {
+ // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0)
+ // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0
+ // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1)
+ // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0
+ // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1
+ // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7)
+ // tuple.3 = (f32[2]{0}) tuple(add.1)
+ // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0
+ // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1
+ // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8,
+ // get-tuple-element.9)
+ // }
+
+ auto wide_param_1 = op::Parameter(0);
+ auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0);
+ auto tuple_1 = op::Tuple(get_tuple_element_1);
+ auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0);
+ auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1);
+ auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7);
+ auto tuple_3 = op::Tuple(add_1);
+ auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0);
+ auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1);
+ auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9);
+
+ EXPECT_THAT(while_body->root_instruction(), tuple_4);
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) {
+ ParseAndVerifyModule(kConstantHoistingTestCase);
+
+ TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+ WhileLoopInvariantCodeMotion{}.Run(&module()));
+ EXPECT_FALSE(simplified_loop);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 7a897f6f8f..2cdee30340 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -55,6 +55,23 @@ string ShapeIndexView::ToString() const {
"}");
}
+bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
+ if (size() != other.size()) {
+ return false;
+ }
+ for (auto it = begin(), other_it = other.begin(); it != end();
+ ++it, ++other_it) {
+ if (*it != *other_it) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
+ return !(*this == other);
+}
+
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
out << shape_index.ToString();
return out;
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 82c75f85d8..cf40068b33 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -132,6 +133,9 @@ class ShapeIndexView {
return ShapeIndexView(new_begin, end_);
}
+ bool operator==(const ShapeIndexView& other) const;
+ bool operator!=(const ShapeIndexView& other) const;
+
string ToString() const;
private:
@@ -626,6 +630,28 @@ class ShapeUtil {
.IgnoreError();
}
+ // These convenience wrappers don't take `base`, `count` and `incr`
+ // explicitly, but iterate over every element in `shape` instead.
+
+ template <typename FnType>
+ static Status ForEachIndexWithStatus(const Shape& shape,
+ const FnType& visitor_function) {
+ std::vector<int64> base(shape.dimensions_size());
+ std::vector<int64> incr(shape.dimensions_size(), 1);
+ return ForEachIndexWithStatus(shape, base,
+ /*count=*/AsInt64Slice(shape.dimensions()),
+ incr, visitor_function);
+ }
+
+ template <typename FnType>
+ static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
+ ForEachIndexWithStatus(shape,
+ [&](tensorflow::gtl::ArraySlice<int64> indices) {
+ return StatusOr<bool>(visitor_function(indices));
+ })
+ .IgnoreError();
+ }
+
// A parallel version of ForEachIndex(WithStatus). This can only be used if
// the visitor_function is thread-safe and the order of iteration does not
// matter.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 4883380be1..a62d49e9c7 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -619,6 +619,7 @@ xla_test(
xla_test(
name = "exhaustive_f32_elementwise_op_test",
+ size = "enormous",
srcs = ["exhaustive_f32_elementwise_op_test.cc"],
backends = [
"cpu",
@@ -626,7 +627,6 @@ xla_test(
],
shard_count = 48,
tags = [
- "enormous",
"manual",
"notap",
],
@@ -776,30 +776,42 @@ xla_test(
],
)
+CONVOLUTION_TEST_DEPS = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+]
+
xla_test(
name = "convolution_test",
timeout = "long",
srcs = ["convolution_test.cc"],
shard_count = 25,
- deps = [
- "//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:reference_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:global_data",
- "//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client:padding",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
- "//tensorflow/compiler/xla/tests:client_library_test_base",
- "//tensorflow/compiler/xla/tests:literal_test_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- ],
+ deps = CONVOLUTION_TEST_DEPS,
+)
+
+xla_test(
+ name = "convolution_test_gpu_alternative_layout",
+ timeout = "long",
+ srcs = ["convolution_test.cc"],
+ backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
+ backends = ["gpu"],
+ shard_count = 25,
+ deps = CONVOLUTION_TEST_DEPS,
)
xla_test(
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 4ef0a77884..722d882471 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -249,10 +249,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
-1.99f,
-2.0f,
-2.01f,
- 0x1.FFFFFEp+62F,
- 0x1.FFFFFCp+62F,
- -0x1.FFFFFEp+62F,
- -0x1.FFFFFCp+62F};
+ 9223371487098961920.f,
+ 9223370937343148032.f,
+ -9223371487098961920.f,
+ -9223370937343148032.f};
std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg});
auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index ec7ca20bdf..3cbb2452fb 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -273,5 +273,112 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
*result, *Literal::CreateR1<float>({0.0, 4.0, 9.0})));
}
+const char* const kScalarOps = R"(
+ HloModule m
+
+ Add {
+ lhsadd = f32[] parameter(0)
+ rhsadd = f32[] parameter(1)
+ ROOT add = f32[] add(lhsadd, rhsadd)
+ }
+
+ Max {
+ lhsmax = f32[] parameter(0)
+ rhsmax = f32[] parameter(1)
+ ROOT max = f32[] maximum(lhsmax, rhsmax)
+ }
+)";
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
+ calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ TF_ASSERT_OK_AND_ASSIGN(auto result,
+ Execute(std::move(module), {param.get()}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result,
+ *Literal::MakeTupleOwned(Literal::CreateR2<float>({{3, 7}, {11, 15}}),
+ Literal::CreateR2<float>({{5, 16}, {36, 64}}))));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput,
+ calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ TF_ASSERT_OK_AND_ASSIGN(auto result,
+ Execute(std::move(module), {param.get()}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result, *Literal::MakeTupleOwned(
+ Literal::CreateR2<float>({{6, 8}, {10, 12}}),
+ Literal::CreateR2<float>({{25, 36}, {49, 64}}))));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
+ r3 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Add
+ ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput,
+ calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ TF_ASSERT_OK_AND_ASSIGN(auto result,
+ Execute(std::move(module), {param.get()}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result, *Literal::MakeTupleOwned(Literal::CreateR1<float>({14, 22}),
+ Literal::CreateR1<float>({36, 64}),
+ Literal::CreateR1<float>({391, 463}))));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 52195db2aa..5653bf11a7 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -197,9 +197,10 @@ class SliceR1Test : public ClientLibraryTestBase,
// vector<bool>.
tensorflow::gtl::InlinedVector<NativeT, 1> input(spec.input_dim0);
std::iota(input.begin(), input.end(), NativeT());
+ auto literal = Literal::CreateR1<NativeT>(input);
XlaBuilder builder(TestName());
- auto original = builder.ConstantR1<NativeT>(input);
+ auto original = builder.Parameter(0, literal->shape(), "p0");
builder.Slice(original, {spec.slice_start}, {spec.slice_limit},
{spec.slice_stride});
@@ -210,7 +211,9 @@ class SliceR1Test : public ClientLibraryTestBase,
expected.push_back(i);
}
- ComputeAndCompareR1<NativeT>(&builder, expected, {});
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
+ client_->TransferToServer(*literal));
+ ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
}
};
@@ -365,15 +368,18 @@ XLA_TEST_P(SliceR2Test, DoIt) {
const R2Spec& spec = GetParam();
Array2D<int32> input(spec.input_dim0, spec.input_dim1);
input.FillUnique();
+ auto literal = Literal::CreateR2FromArray2DWithLayout(
+ input, LayoutUtil::MakeLayout(spec.layout));
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2FromArray2DWithLayout<int32>(
- input, LayoutUtil::MakeLayout(spec.layout));
+ auto a = builder.Parameter(0, literal->shape(), "p0");
builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
+ client_->TransferToServer(*literal));
std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
- ComputeAndCompareR2<int32>(&builder, *expected, {});
+ ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
}
INSTANTIATE_TEST_CASE_P(
@@ -453,7 +459,7 @@ class SliceR4Test : public ClientLibraryTestBase,
void Run(const R4Spec& spec) {
Array4D<float> values(spec.input_dims[0], spec.input_dims[1],
spec.input_dims[2], spec.input_dims[3]);
- values.FillRandom(3.14f);
+ values.FillIota(3.14159);
auto expected = ReferenceUtil::Slice4D(
values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 415cf9c16a..15b9cd4265 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -86,6 +86,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
"//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/compiler/xla/service/gpu:infeed_manager",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index d0e7af8844..e990b6aba8 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -56,6 +56,11 @@ class HloParser {
// Returns the error information.
string GetError() const { return Join(error_, "\n"); }
+ // Stand alone parsing for sharding. The parser string is supposed to
+ // contain the body of the sharding, i.e. just the rhs of the "sharding={...}"
+ // attribute string.
+ StatusOr<HloSharding> ParseShardingOnly();
+
private:
// ParseXXX returns false if an error occurred.
bool ParseHloModule();
@@ -2673,6 +2678,18 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation,
return true;
}
+StatusOr<HloSharding> HloParser::ParseShardingOnly() {
+ lexer_.Lex();
+ OpSharding op_sharding;
+ if (!ParseSharding(&op_sharding)) {
+ return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ }
+ if (lexer_.GetKind() != TokKind::kEof) {
+ return InvalidArgument("Syntax error:\nExtra content after sharding");
+ }
+ return HloSharding::FromProto(op_sharding);
+}
+
} // namespace
StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
@@ -2689,5 +2706,11 @@ StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
return Parse(str, config);
}
+StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ return parser.ParseShardingOnly();
+}
+
} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h
index 2f97a2b9b1..f7854f403e 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h
@@ -36,6 +36,10 @@ StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str,
// format, parses the string and creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
+// Parse sharding from str. str is supposed to contain the body of the
+// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
+StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+
} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index df0501386c..2349fa919e 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -63,6 +64,7 @@ namespace {
// fields.
struct Options {
string fake_infeed_shape;
+ bool generate_fake_infeed = false;
bool use_fake_data = false;
bool print_result = true;
int num_runs = 1;
@@ -72,8 +74,12 @@ struct Options {
// Invokes the given computation passing arbitrary data for every (unbound)
// parameter if use_fake_data, Otherwise use recorded data if available.
//
-// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
-// otherwise, no infeed is performed.
+// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided.
+// If generate_fake_infeed is true, the required infeed shape is derived from
+// the computation and then used to provide a fake infeed shape.
+//
+// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
+// no infeed is performed.
StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
Client* client,
const Options& opts) {
@@ -92,24 +98,54 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
}
}
+ bool provide_infeed = false;
+ Shape infeed_shape;
+ if (!opts.fake_infeed_shape.empty()) {
+ StatusOr<Shape> shape_status =
+ ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
+ TF_CHECK_OK(shape_status.status());
+ infeed_shape = std::move(shape_status).ValueOrDie();
+ provide_infeed = true;
+ } else if (opts.generate_fake_infeed) {
+ for (const auto& comp : computation.proto().computations()) {
+ for (const auto& instruction : comp.instructions()) {
+ if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) {
+ CHECK(!provide_infeed)
+ << "--generate_fake_infeed only works if the model has 0 or 1 "
+ "infeed ops, but this one has >= 2.";
+ provide_infeed = true;
+ infeed_shape = instruction.shape();
+ LOG(INFO) << "Generating fake infeed shape for inferred shape: "
+ << ShapeUtil::HumanString(infeed_shape);
+ }
+ }
+ }
+ }
// We only instantiate the thread pool if the user has requested that a
- // concurrent infeed occur via the fake_infeed_shape.
+ // concurrent infeed occur via the fake_infeed_shape, or when
+ // --generate_fake_infeed is passed and there exists an infeed operation in
+ // the HloSnapshot.
tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
-
- if (!opts.fake_infeed_shape.empty()) {
+ std::unique_ptr<Literal> data;
+ if (provide_infeed) {
+ data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie();
+ }
+ auto transfer_infeed = [&data, client]() {
+ TF_CHECK_OK(client->TransferToInfeed(*data));
+ };
+ if (provide_infeed) {
pool.emplace(tensorflow::Env::Default(), "infeed",
/*num_threads=*/1);
- pool->Schedule([opts, client]() {
- StatusOr<Shape> shape_status =
- ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
- TF_CHECK_OK(shape_status.status());
- Shape shape = std::move(shape_status).ValueOrDie();
- StatusOr<std::unique_ptr<Literal>> data_status = MakeFakeLiteral(shape);
- TF_CHECK_OK(data_status.status());
- std::unique_ptr<Literal> data = std::move(data_status).ValueOrDie();
- while (true) {
- TF_CHECK_OK(client->TransferToInfeed(*data));
- }
+ pool->Schedule([transfer_infeed]() {
+ // There may be several infeed buffers needed, however we don't know how
+ // many. If we proactively transfer too many infeed buffers, we may run
+ // out of memory. If we transfer too few infeed buffers, the program will
+ // hang. Therefore, we register a callback that is called when the infeed
+ // becomes empty, and in this callback we will transfer another fake
+ // infeed.
+ auto infeed_manager = xla::gpu::GetOrCreateInfeedManager();
+ infeed_manager->RegisterOnEmptyCallback(transfer_infeed);
+ transfer_infeed();
});
}
@@ -204,6 +240,9 @@ int main(int argc, char** argv) {
"Number of times to run each computation"),
tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape,
"Shape of fake data to construct for (infinite) infeed"),
+ tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
+ "Whether a fake infeed shape should be generated "
+ "derived from the computation"),
tensorflow::Flag(
"xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
"Pass --xla_hlo_profile the last time we run the computation."),
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index be33bd6dd1..7303640726 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -219,6 +219,12 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
Status InvalidArgumentV(const char* format, va_list args);
template <typename... Args>
+Status InvalidArgumentStrCat(Args&&... concat) {
+ return InvalidArgument(
+ "%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
+}
+
+template <typename... Args>
Status UnimplementedStrCat(Args&&... concat) {
return Unimplemented(
"%s", tensorflow::strings::StrCat(std::forward<Args>(concat)...).c_str());
@@ -486,6 +492,12 @@ bool c_is_sorted(const C& c) {
return std::is_sorted(std::begin(c), std::end(c));
}
+template <typename C, typename Compare>
+bool c_is_sorted(const C& c, Compare&& comp) {
+ return std::is_sorted(std::begin(c), std::end(c),
+ std::forward<Compare>(comp));
+}
+
template <typename C>
auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) {
return std::adjacent_find(std::begin(c), std::end(c));
@@ -520,6 +532,16 @@ int64 FindIndex(const C& c, Value&& value) {
return std::distance(c.begin(), it);
}
+template <typename C, typename Value>
+void InsertAt(C* c, int64 index, Value&& value) {
+ c->insert(c->begin() + index, std::forward<Value>(value));
+}
+
+template <typename C>
+void EraseAt(C* c, int64 index) {
+ c->erase(c->begin() + index);
+}
+
// Returns true if `x` fits in 32-bits.
template <typename T>
bool IsInt32(T x) {
diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc
index 707853b59b..30de7b59af 100644
--- a/tensorflow/contrib/android/jni/run_stats_jni.cc
+++ b/tensorflow/contrib/android/jni/run_stats_jni.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/android/jni/run_stats_jni.h"
#include <jni.h>
+
#include <sstream>
#include "tensorflow/core/protobuf/config.pb.h"
@@ -73,7 +74,8 @@ JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz,
StatSummarizer* s = requireHandle(env, handle);
if (s == nullptr) return nullptr;
std::stringstream ret;
- ret << s->GetStatsByMetric("Top 10 CPU", StatSummarizer::BY_TIME, 10)
+ ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME,
+ 10)
<< s->GetStatsByNodeType() << s->ShortSummary();
return env->NewStringUTF(ret.str().c_str());
}
diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md
index a7a3fe1452..a4aec8c74a 100644
--- a/tensorflow/contrib/autograph/CONTRIBUTING.md
+++ b/tensorflow/contrib/autograph/CONTRIBUTING.md
@@ -2,6 +2,9 @@
We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below.
+## TensorFlow Code of Conduct
+Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md).
+
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
@@ -28,7 +31,7 @@ repository (with credit to the original author) and closes the pull request.
## Style
-See the [TensorFlow AutoGraph style guide](STYLE_GUIDE.md).
+See the [AutoGraph style guide](STYLE_GUIDE.md).
## Unit tests
diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md
index 5618ec3e34..866e5f583a 100644
--- a/tensorflow/contrib/autograph/STYLE_GUIDE.md
+++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md
@@ -1,43 +1,26 @@
-# TensorFlow AutoGraph Style Guide
+# AutoGraph Style Guide
-This page contains style decisions that both developers and users of TensorFlow
-AutoGraph should follow to increase the readability of their code, reduce the
-number of errors, and promote consistency. We borrow many style principles from the TensorFlow Probability style guide.
+This page contains style decisions that developers should follow when
+contributing code to AutoGraph.
## TensorFlow Style
Follow the [TensorFlow style
-guide](https://www.tensorflow.org/community/style_guide) and [documentation
-guide](https://www.tensorflow.org/community/documentation). Below are additional
-TensorFlow conventions not noted in those guides. In the future, these noted
-conventions may be moved upstream.
+guide](https://www.tensorflow.org/community/style_guide), the [documentation
+guide](https://www.tensorflow.org/community/documentation) and the
+[Google Python style guide](https://google.github.io/styleguide/pyguide.html).
+
+Naming conventions:
1. The name is TensorFlow, not Tensorflow.
2. The name is AutoGraph, not Autograph.
-## TensorFlow Code of Conduct
-Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md).
-
-## TensorFlow AutoGraph Style
+## AutoGraph Style
-Below are TensorFlow AutoGraph-specific conventions. In the event of conflict,
+Below are AutoGraph-specific conventions. In the event of conflict,
it supercedes all previous conventions.
-1. __Importing submodule aliases.__ Use the Pythonic style
-`from tensorflow.contrib.autograph.converters import ifexp` and `from tensorflow.contrib import autograph as ag`.
-
-2. __Examples in Docstrings.__ Write a `#### Examples` subsection below `Args`,
- `Returns`, `Raises`, etc. to illustrate examples. If the docstring's last
- line is a fence bracket (\`\`\`) closing a code snippet, add an empty line
- before closing the docstring with \"\"\". This properly displays the code
- snippet.
-
- Justification: Users regularly need to remind themselves of args and
- semantics. But rarely look at examples more than the first time. But since
- examples are usually long (which is great!) it means they have to do a lot
- of annoying scrolling ...unless Examples follow Args/Returns/Raises.
-
-3. __Citations in Docstrings.__ Write a `#### References` subsection at the
+1. __Citations in Docstrings.__ Write a `#### References` subsection at the
bottom of any docstring with citations. Use ICLR’s bibliography style to
write references; for example, order entries by the first author's last
name. Add a link to the paper if the publication is open source (ideally,
@@ -77,21 +60,12 @@ it supercedes all previous conventions.
https://arxiv.org/abs/1803.04386
```
-4. When doing float math over literals eg use `1.` instead of `1` or `1.0`.
-
- * Using `1.` is another line of defense against an automatic casting
- mistake. (Using `1.0` is also such a defense but is not minimal.)
-
-5. Prefer using named args for functions' 2nd args onward.
-
- * Definitely use named args for 2nd args onward in docstrings.
-
-9. Avoid LaTeX in docstrings.
+2. Avoid LaTeX in docstrings.
* It is not rendered in many (if not most) editors and can be hard to read
for both LaTeX experts and non-experts.
-10. Write docstring and comment math using ASCII friendly notation; python using
+3. Write docstring and comment math using ASCII friendly notation; python using
operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`,
`sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx:
x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`.
@@ -99,27 +73,3 @@ it supercedes all previous conventions.
* The more we stick to python style, the more someone can
copy/paste/execute.
* Python style is usually easier to read as ASCII.
-
-11. All public functions require docstrings with: one line description, Args,
- Returns, Raises (if raises exceptions).
-
- * Returns docstrings should be in the same format as Args, eg, of the form
- "name: Description." Part of the rationale is that we are suggesting a
- reasonable variable name for the returned object(s).
-
-12. Regard `*args` and/or `**kwargs` as features of last resort.
-
- * Keyword arguments make the intention of a function call more clear.
- * [Possible exceptions for
- `kwargs`](https://stackoverflow.com/questions/1415812/why-use-kwargs-in-python-what-are-some-real-world-advantages-over-using-named).
-
-18. The `__init__.py` file for modules should use TensorFlow's
- `remove_undocumented` feature, which seals the module's methods.
-
-21. Use `"{}".format()` rather than `"" %` for string formatting.
-
- Justification: [PEP 3101](https://www.python.org/dev/peps/pep-3101/) and
- [Python official
- tutorials](https://docs.python.org/3.2/tutorial/inputoutput.html#old-string-formatting):
- "...this old style of formatting will eventually be removed from the
- language, str.format() should generally be used."
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 35877224b8..5b7508c9a5 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import gast
-
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
@@ -54,13 +52,9 @@ class BreakStatementTransformer(transformer.Base):
def _guard_if_present(self, block, var_name):
"""Prevents the block from executing if var_name is set."""
-
- # If we don't have statements that immediately depend on the break
- # we still need to make sure that the break variable remains
- # used, in case the break becomes useful in later stages of transformation.
- # Not having this broke the break_in_inner_loop test.
if not block:
- block = [gast.Pass()]
+ return block
+
template = """
if not var_name:
block
@@ -73,7 +67,7 @@ class BreakStatementTransformer(transformer.Base):
def visit_While(self, node):
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- break_var = self.context.namer.new_symbol('break__', scope.referenced)
+ break_var = self.context.namer.new_symbol('break_', scope.referenced)
node.test = self.visit(node.test)
node.body, break_used = self._track_body(node.body, break_var)
@@ -81,6 +75,10 @@ class BreakStatementTransformer(transformer.Base):
node.orelse = self.visit_block(node.orelse)
if break_used:
+ # Python's else clause only triggers if the loop exited cleanly (e.g.
+ # break did not trigger).
+ guarded_orelse = self._guard_if_present(node.orelse, break_var)
+
template = """
var_name = False
while test and not var_name:
@@ -88,20 +86,18 @@ class BreakStatementTransformer(transformer.Base):
else:
orelse
"""
- # Python's else clause only triggers if the loop exited cleanly (e.g.
- # break did not trigger).
node = templates.replace(
template,
var_name=break_var,
test=node.test,
body=node.body,
- orelse=self._guard_if_present(node.orelse, break_var))
+ orelse=guarded_orelse)
return node
def visit_For(self, node):
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- break_var = self.context.namer.new_symbol('break__', scope.referenced)
+ break_var = self.context.namer.new_symbol('break_', scope.referenced)
node.target = self.visit(node.target)
node.iter = self.visit(node.iter)
@@ -110,19 +106,32 @@ class BreakStatementTransformer(transformer.Base):
node.orelse = self.visit_block(node.orelse)
if break_used:
- node.orelse = self._guard_if_present(node.orelse, break_var)
+ # Python's else clause only triggers if the loop exited cleanly (e.g.
+ # break did not trigger).
+ guarded_orelse = self._guard_if_present(node.orelse, break_var)
+ extra_test = templates.replace_as_expression(
+ 'not var_name', var_name=break_var)
+
+ # The extra test is hidden in the AST, which will confuse the static
+ # analysis. To mitigate that, we insert a no-op statement that ensures
+ # the control variable is marked as used.
+ # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
template = """
var_name = False
- for_stmt
+ for target in iter_:
+ (var_name,)
+ body
+ else:
+ orelse
"""
- # Python's else clause only triggers if the loop exited cleanly (e.g.
- # break did not trigger).
node = templates.replace(
template,
var_name=break_var,
- for_stmt=node)
- extra_test = templates.replace_as_expression(
- 'not var_name', var_name=break_var)
+ iter_=node.iter,
+ target=node.target,
+ body=node.body,
+ orelse=guarded_orelse)
+
anno.setanno(node[1], 'extra_test', extra_test)
return node
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py
index 317711a866..46e39da16a 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions.py
@@ -31,9 +31,6 @@ class BuiltinFunctionTransformer(transformer.Base):
TF equivalent, like `len`.
"""
- def __init__(self, context):
- super(BuiltinFunctionTransformer, self).__init__(context)
-
def _convert_builtin(self, node):
template = """
ag__.utils.dynamic_builtin(func, args)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
index c00946f9c4..d6555dc7e0 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
@@ -136,14 +136,14 @@ class TypeInfoResolver(transformer.Base):
def _process_function_arg(self, arg_name):
str_name = str(arg_name)
+ type_holder = arg_name.ast()
+ self.scope.setval(arg_name, type_holder)
if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types:
# Forge a node to hold the type information, so that method calls on
# it can resolve the type.
- type_holder = arg_name.ast()
type_string, type_obj = self.context.arg_types[str_name]
anno.setanno(type_holder, 'type', type_obj)
anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
- self.scope.setval(arg_name, type_holder)
def visit_arg(self, node):
self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
@@ -167,50 +167,41 @@ class TypeInfoResolver(transformer.Base):
anno.getanno(definition, 'element_type'))
return node
- def _process_variable_assignment(self, source, targets):
- # Special case: constructors.
- if isinstance(source, gast.Call):
- func = source.func
+ def _process_variable_assignment(self, target, value):
+ # Constructors
+ if isinstance(value, gast.Call):
+ func = value.func
if anno.hasanno(func, 'live_val'):
func_obj = anno.getanno(func, 'live_val')
if tf_inspect.isclass(func_obj):
- anno.setanno(source, 'is_constructor', True)
- anno.setanno(source, 'type', func_obj)
- anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
+ anno.setanno(value, 'is_constructor', True)
+ anno.setanno(value, 'type', func_obj)
+ anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn'))
# TODO(mdan): Raise an error if constructor has side effects.
# We can have a whitelist of no-side-effects constructors.
# We can also step inside the constructor and further analyze.
- # Multiple targets mean multiple assignment.
- for target in targets:
- # Tuple target means unpacking.
- if isinstance(target, (gast.Tuple, gast.List)):
- for i, target_item in enumerate(target.elts):
- # Two cases here:
- # 1. Static unpacking, e.g. a, b = c, d
- # 2. Dynamic unpacking, e.g. a, b = c
- # The former case is optimized away.
- if isinstance(source, (gast.Tuple, gast.List)):
- source_item = source.elts[i]
- else:
- source_item = gast.Subscript(source, gast.Index(i), ctx=None)
- self._process_variable_assignment(source_item, (target_item,))
- elif isinstance(target, (gast.Name, gast.Attribute)):
- target_symbol = anno.getanno(target, anno.Basic.QN)
- self.scope.setval(target_symbol, source)
- else:
- raise ValueError('assignment target has unknown type: %s' % target)
+ if isinstance(target, (gast.Name, gast.Attribute)):
+ target_symbol = anno.getanno(target, anno.Basic.QN)
+ self.scope.setval(target_symbol, value)
+ elif isinstance(target, gast.Subscript):
+ pass
+ else:
+ raise ValueError('assignment target has unknown type: %s' % target)
def visit_With(self, node):
- for wi in node.items:
- if wi.optional_vars is not None:
- self._process_variable_assignment(wi.context_expr, (wi.optional_vars,))
+ for item in node.items:
+ if item.optional_vars is not None:
+ self.apply_to_single_assignments((item.optional_vars,),
+ item.context_expr,
+ self._process_variable_assignment)
self.generic_visit(node)
return node
def visit_Assign(self, node):
self.generic_visit(node)
- self._process_variable_assignment(node.value, node.targets)
+ self.apply_to_single_assignments(
+ node.targets, node.value, self._process_variable_assignment)
return node
def visit_Call(self, node):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
index 46b7701624..95cbf5ca79 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
@@ -196,6 +196,19 @@ class TypeInfoResolverTest(test.TestCase):
f_ref = node.body[0].body[1].value
self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+ def test_type_annotation_args(self):
+
+ class Foo(object):
+ pass
+
+ def test_fn(f):
+ utils.set_element_type(f, Foo)
+ return f
+
+ node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
+ f_ref = node.body[0].body[1].value
+ self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
+
def test_nested_unpacking(self):
class Foo(object):
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py
index 4db6cc0adf..4c65edb6de 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/contrib/autograph/pyct/transformer.py
@@ -103,6 +103,54 @@ class Base(gast.NodeTransformer):
results.append(replacement)
return results
+ # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+ def apply_to_single_assignments(self, targets, values, apply_fn):
+ """Applies a fuction to each individual assignment.
+
+ This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
+ It tries to break down the unpacking if possible. In effect, it has the same
+ effect as passing the assigned values in SSA form to apply_fn.
+
+ Examples:
+
+ The following will result in apply_fn(a, c), apply_fn(b, d):
+
+ a, b = c, d
+
+ The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
+
+ a, b = c
+
+ The following will result in apply_fn(a, (b, c)):
+
+ a = b, c
+
+ It uses the visitor pattern to allow subclasses to process single
+ assignments individually.
+
+ Args:
+ targets: list, tuple of or individual AST node. Should be used with the
+ targets field of an ast.Assign node.
+ values: an AST node.
+ apply_fn: a function of a single argument, which will be called with the
+ respective nodes of each single assignment. The signaure is
+ apply_fn(target, value), no return value.
+ """
+ if not isinstance(targets, (list, tuple)):
+ targets = (targets,)
+ for target in targets:
+ if isinstance(target, (gast.Tuple, gast.List)):
+ for i in range(len(target.elts)):
+ target_el = target.elts[i]
+ if isinstance(values, (gast.Tuple, gast.List)):
+ value_el = values.elts[i]
+ else:
+ value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store())
+ self.apply_to_single_assignments(target_el, value_el, apply_fn)
+ else:
+ # TODO(mdan): Look into allowing to rewrite the AST here.
+ apply_fn(target, values)
+
def visit(self, node):
source_code = self.context.source_code
source_file = self.context.source_file
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py
index f96b0dc377..1f1adf4fbd 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/contrib/autograph/pyct/transformer_test.py
@@ -94,7 +94,7 @@ class TransformerTest(test.TestCase):
inner_function, lambda_node),
anno.getanno(lambda_expr, 'enclosing_entities'))
- def test_statement_info_stack(self):
+ def test_local_scope_info_stack(self):
class TestTransformer(transformer.Base):
@@ -142,7 +142,7 @@ class TransformerTest(test.TestCase):
self.assertFalse(anno.hasanno(while_node, 'string'))
self.assertEqual('1', anno.getanno(while_node, 'test'))
- def test_statement_info_stack_checks_integrity(self):
+ def test_local_scope_info_stack_checks_integrity(self):
class TestTransformer(transformer.Base):
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index d65c990c87..b6dae3cc1f 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -96,6 +96,7 @@ py_test(
name = "batch_ops_test",
size = "small",
srcs = ["python/ops/batch_ops_test.py"],
+ shard_count = 5,
srcs_version = "PY2AND3",
tags = [
"manual",
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 04e32267cc..401bec84a2 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -43,47 +43,60 @@ namespace {
const int32 DUMMY_FEATURE_DIMENSION = -1;
} // namespace
-class BaseBuildSplitOp : public OpKernel {
+class SplitBuilderState {
public:
- explicit BaseBuildSplitOp(OpKernelConstruction* const context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id",
- &feature_column_group_id_));
+ explicit SplitBuilderState(OpKernelContext* const context) {
+ const Tensor* l1_regularization_t;
OP_REQUIRES_OK(context,
- context->GetAttr("l1_regularization", &l1_regularization_));
+ context->input("l1_regularization", &l1_regularization_t));
+ const Tensor* l2_regularization_t;
OP_REQUIRES_OK(context,
- context->GetAttr("l2_regularization", &l2_regularization_));
- OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization",
- &tree_complexity_regularization_));
+ context->input("l2_regularization", &l2_regularization_t));
+ const Tensor* tree_complexity_regularization_t;
+ OP_REQUIRES_OK(context, context->input("tree_complexity_regularization",
+ &tree_complexity_regularization_t));
+ const Tensor* min_node_weight_t;
OP_REQUIRES_OK(context,
- context->GetAttr("min_node_weight", &min_node_weight_));
+ context->input("min_node_weight", &min_node_weight_t));
- int strategy;
- OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy));
+ const Tensor* feature_column_group_id_t;
+ OP_REQUIRES_OK(context, context->input("feature_column_group_id",
+ &feature_column_group_id_t));
+
+ const Tensor* multiclass_strategy_t;
+ OP_REQUIRES_OK(
+ context, context->input("multiclass_strategy", &multiclass_strategy_t));
+ int strategy = multiclass_strategy_t->scalar<int32>()();
OP_REQUIRES(
context,
boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid(
strategy),
errors::InvalidArgument("Wrong multiclass strategy passed."));
- multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy);
- }
- NodeStats ComputeNodeStats(const GradientStats& grad_stats) {
- return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_,
- multiclass_strategy_, grad_stats);
- }
+ multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy);
- void ReadClassId(OpKernelContext* const context, int32* class_id) {
const Tensor* class_id_t;
OP_REQUIRES_OK(context, context->input("class_id", &class_id_t));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()),
errors::InvalidArgument("class_id must be a scalar."));
- *class_id = class_id_t->scalar<int32>()();
+ class_id_ = class_id_t->scalar<int32>()();
+
+ l1_regularization_ = l1_regularization_t->scalar<float>()();
+ l2_regularization_ = l2_regularization_t->scalar<float>()();
+ tree_complexity_regularization_ =
+ tree_complexity_regularization_t->scalar<float>()();
+ min_node_weight_ = min_node_weight_t->scalar<float>()();
+ feature_column_group_id_ = feature_column_group_id_t->scalar<int32>()();
+ }
+
+ NodeStats ComputeNodeStats(const GradientStats& grad_stats) {
+ return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_,
+ multiclass_strategy_, grad_stats);
}
- void FillLeaf(const int class_id, const NodeStats& best_node_stats,
+ void FillLeaf(const NodeStats& best_node_stats,
boosted_trees::trees::Leaf* leaf) const {
- if (class_id == -1) {
+ if (class_id_ == -1) {
// This would be the case either for TREE_PER_CLASS with only 2 classes,
// or for other multiclass strategies.
for (float f : best_node_stats.weight_contribution) {
@@ -93,25 +106,31 @@ class BaseBuildSplitOp : public OpKernel {
CHECK(best_node_stats.weight_contribution.size() == 1)
<< "Weight contribution size = "
<< best_node_stats.weight_contribution.size();
- leaf->mutable_sparse_vector()->add_index(class_id);
+ leaf->mutable_sparse_vector()->add_index(class_id_);
leaf->mutable_sparse_vector()->add_value(
best_node_stats.weight_contribution[0]);
}
}
- protected:
+ int32 feature_column_group_id() { return feature_column_group_id_; }
+ float tree_complexity_regularization() {
+ return tree_complexity_regularization_;
+ }
+
+ private:
LearnerConfig_MultiClassStrategy multiclass_strategy_;
- int32 feature_column_group_id_;
float l1_regularization_;
float l2_regularization_;
- float min_node_weight_;
float tree_complexity_regularization_;
+ float min_node_weight_;
+ int32 class_id_;
+ int32 feature_column_group_id_;
};
-class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
+class BuildDenseInequalitySplitsOp : public OpKernel {
public:
explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context)
- : BaseBuildSplitOp(context) {}
+ : OpKernel(context) {}
void Compute(OpKernelContext* const context) override {
const Tensor* num_minibatches_t;
@@ -139,9 +158,6 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
- int class_id;
- ReadClassId(context, &class_id);
-
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
partition_boundaries.push_back(0);
@@ -185,6 +201,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
&output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ SplitBuilderState state(context);
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[root_idx];
@@ -196,7 +213,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
int32 best_bucket_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -206,10 +223,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
GradientStats g(*gradients_t, *hessians_t, bucket_idx);
g *= normalizer_ratio;
left_gradient_stats += g;
- NodeStats left_stats = ComputeNodeStats(left_gradient_stats);
+ NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats right_stats = ComputeNodeStats(right_gradient_stats);
+ NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -220,18 +237,18 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
SplitInfo split_info;
auto* dense_split =
split_info.mutable_split_node()->mutable_dense_float_binary_split();
- dense_split->set_feature_column(feature_column_group_id_);
+ dense_split->set_feature_column(state.feature_column_group_id());
dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- FillLeaf(class_id, best_left_node_stats, left_child);
- FillLeaf(class_id, best_right_node_stats, right_child);
+ state.FillLeaf(best_left_node_stats, left_child);
+ state.FillLeaf(best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
- best_gain - root_stats.gain - tree_complexity_regularization_;
+ best_gain - root_stats.gain - state.tree_complexity_regularization();
output_partition_ids(root_idx) = partition_ids(start_index);
}
}
@@ -239,13 +256,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),
BuildDenseInequalitySplitsOp);
-class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
+class BuildSparseInequalitySplitsOp : public OpKernel {
public:
explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context)
- : BaseBuildSplitOp(context) {
- OP_REQUIRES_OK(context,
- context->GetAttr("bias_feature_id", &bias_feature_id_));
- }
+ : OpKernel(context) {}
void Compute(OpKernelContext* const context) override {
const Tensor* num_minibatches_t;
@@ -275,8 +289,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
- int class_id;
- ReadClassId(context, &class_id);
+ const Tensor* bias_feature_id_t;
+ OP_REQUIRES_OK(context,
+ context->input("bias_feature_id", &bias_feature_id_t));
+ int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
// For each partition (tree node), store starting index for each dimension.
PartitionAndDimensionBoundaries partition_boundaries;
@@ -354,6 +370,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
&output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ SplitBuilderState state(context);
// For each tree node that needs to be split.
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
const auto& dimension_boundaries =
@@ -372,7 +389,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
OP_REQUIRES(
context,
- bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_,
+ bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id,
errors::InvalidArgument("Bias feature ID missing."));
// Dimension for bias feature is always 0
@@ -388,7 +405,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
GradientStats root_gradient_stats(*gradients_t, *hessians_t,
bias_start_index);
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
// Iterate through dimensions.
for (int j = 0; j < dimension_boundaries.size() - 1; ++j) {
@@ -408,7 +425,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
<< bucket_ids_and_dimensions(start_index, 1) << " and for "
<< bucket_ids_and_dimensions(end_index - 1, 0) << " "
<< bucket_ids_and_dimensions(end_index - 1, 1);
- if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) {
+ if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) {
// 0-dimension case which has a first bucket for catch all feature.
CHECK(bucket_ids_and_dimensions(start_index, 1) == 0)
<< "Dimension of bias feature should be 0";
@@ -447,10 +464,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
present_gradient_stats - left_gradient_stats;
{
- NodeStats left_stats_default_left =
- ComputeNodeStats(root_gradient_stats - right_gradient_stats);
+ NodeStats left_stats_default_left = state.ComputeNodeStats(
+ root_gradient_stats - right_gradient_stats);
NodeStats right_stats_default_left =
- ComputeNodeStats(right_gradient_stats);
+ state.ComputeNodeStats(right_gradient_stats);
if (left_stats_default_left.gain + right_stats_default_left.gain >
best_gain) {
best_gain =
@@ -466,9 +483,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
// enough missing examples.
if (!fixed_default_direction) {
NodeStats left_stats_default_right =
- ComputeNodeStats(left_gradient_stats);
- NodeStats right_stats_default_right =
- ComputeNodeStats(root_gradient_stats - left_gradient_stats);
+ state.ComputeNodeStats(left_gradient_stats);
+ NodeStats right_stats_default_right = state.ComputeNodeStats(
+ root_gradient_stats - left_gradient_stats);
if (left_stats_default_right.gain + right_stats_default_right.gain >
best_gain) {
best_gain = left_stats_default_right.gain +
@@ -494,7 +511,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
->mutable_sparse_float_binary_split_default_left()
->mutable_split();
}
- dense_split->set_feature_column(feature_column_group_id_);
+ dense_split->set_feature_column(state.feature_column_group_id());
// Set the feature index for the best feature column.
const int64 best_dimension_id =
bucket_ids_and_dimensions(best_element_idx, 1);
@@ -505,11 +522,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- FillLeaf(class_id, best_left_node_stats, left_child);
- FillLeaf(class_id, best_right_node_stats, right_child);
+ state.FillLeaf(best_left_node_stats, left_child);
+ state.FillLeaf(best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
- best_gain - root_stats.gain - tree_complexity_regularization_;
+ best_gain - root_stats.gain - state.tree_complexity_regularization();
output_partition_ids(root_idx) = partition_ids(bias_start_index);
}
}
@@ -526,19 +543,14 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
// For each partition, store start indices of feature column dimensions.
typedef std::vector<std::vector<DimensionBoundary>>
PartitionAndDimensionBoundaries;
-
- int64 bias_feature_id_;
};
REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU),
BuildSparseInequalitySplitsOp);
-class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
+class BuildCategoricalEqualitySplitsOp : public OpKernel {
public:
explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context)
- : BaseBuildSplitOp(context) {
- OP_REQUIRES_OK(context,
- context->GetAttr("bias_feature_id", &bias_feature_id_));
- }
+ : OpKernel(context) {}
void Compute(OpKernelContext* const context) override {
const Tensor* num_minibatches_t;
@@ -561,8 +573,10 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
- int class_id;
- ReadClassId(context, &class_id);
+ const Tensor* bias_feature_id_t;
+ OP_REQUIRES_OK(context,
+ context->input("bias_feature_id", &bias_feature_id_t));
+ int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
@@ -605,16 +619,17 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
&output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ SplitBuilderState state(context);
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[non_empty_partitions[root_idx]];
int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1];
// First feature ID in each partition should be the bias feature.
- OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_,
+ OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
errors::InvalidArgument("Bias feature ID missing."));
GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
int32 best_feature_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -625,8 +640,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
left_gradient_stats *= normalizer_ratio;
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats left_stats = ComputeNodeStats(left_gradient_stats);
- NodeStats right_stats = ComputeNodeStats(right_gradient_stats);
+ NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
+ NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -637,21 +652,18 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
SplitInfo split_info;
auto* equality_split = split_info.mutable_split_node()
->mutable_categorical_id_binary_split();
- equality_split->set_feature_column(feature_column_group_id_);
+ equality_split->set_feature_column(state.feature_column_group_id());
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- FillLeaf(class_id, best_left_node_stats, left_child);
- FillLeaf(class_id, best_right_node_stats, right_child);
+ state.FillLeaf(best_left_node_stats, left_child);
+ state.FillLeaf(best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
- best_gain - root_stats.gain - tree_complexity_regularization_;
+ best_gain - root_stats.gain - state.tree_complexity_regularization();
output_partition_ids(root_idx) = partition_ids(start_index);
}
}
-
- private:
- int64 bias_feature_id_;
};
REGISTER_KERNEL_BUILDER(
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index f06b73c00d..409a2d8f46 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -64,6 +64,8 @@ from __future__ import print_function
import re
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
+from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
@@ -72,9 +74,11 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+
_BIAS_FEATURE_ID = -1
# Pattern to remove all non alpha numeric from a string.
_PATTERN = re.compile(r"[\W_]+")
@@ -130,11 +134,14 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler):
gradient_shape,
hessian_shape,
name="StatsAccumulator/{}".format(self._name))
- self._quantile_accumulator = quantile_ops.QuantileAccumulator(
- init_stamp_token,
- epsilon=epsilon,
- num_quantiles=num_quantiles,
- name="QuantileAccumulator/{}".format(self._name))
+ # Allocate both stats accumulator and quantile accumulator on the same
+ # device so that we can build splits with fewer RPCs.
+ with ops.colocate_with(self._stats_accumulator.resource()):
+ self._quantile_accumulator = quantile_ops.QuantileAccumulator(
+ init_stamp_token,
+ epsilon=epsilon,
+ num_quantiles=num_quantiles,
+ name="QuantileAccumulator/{}".format(self._name))
class DenseSplitHandler(InequalitySplitHandler):
@@ -236,45 +243,74 @@ class DenseSplitHandler(InequalitySplitHandler):
def make_splits(self, stamp_token, next_stamp_token, class_id):
"""Create the best split using the accumulated stats and flush the state."""
- # Get the bucket boundaries
- are_splits_ready, buckets = (
- self._quantile_accumulator.get_buckets(stamp_token))
- # After we receive the boundaries from previous iteration we can flush
- # the quantile accumulator.
- with ops.control_dependencies([buckets]):
- flush_quantiles = self._quantile_accumulator.flush(
- stamp_token=stamp_token, next_stamp_token=next_stamp_token)
-
- # Get the aggregated gradients and hessians per <partition_id, feature_id>
- # pair.
- # In order to distribute the computation on all the PSs we use the PS that
- # had the stats accumulator on.
- with ops.device(None):
- with ops.device(self._stats_accumulator.resource().device):
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- self._stats_accumulator.flush(stamp_token, next_stamp_token))
-
- # Put quantile and stats accumulator flushing in the dependency path.
- are_splits_ready = control_flow_ops.with_dependencies(
- [flush_quantiles, partition_ids], are_splits_ready)
-
- partition_ids, gains, split_infos = (
- split_handler_ops.build_dense_inequality_splits(
- num_minibatches=num_minibatches,
- bucket_boundaries=buckets,
- partition_ids=partition_ids,
- bucket_ids=bucket_ids,
- gradients=gradients,
- hessians=hessians,
- class_id=class_id,
- feature_column_group_id=self._feature_column_group_id,
- l1_regularization=self._l1_regularization,
- l2_regularization=self._l2_regularization,
- tree_complexity_regularization=self.
- _tree_complexity_regularization,
- min_node_weight=self._min_node_weight,
- multiclass_strategy=self._multiclass_strategy))
- return (are_splits_ready, partition_ids, gains, split_infos)
+ if (self._gradient_shape == tensor_shape.scalar() and
+ self._hessian_shape == tensor_shape.scalar()):
+ handler = make_dense_split_scalar
+ else:
+ handler = make_dense_split_tensor
+
+ are_splits_ready, partition_ids, gains, split_infos = (
+ handler(self._quantile_accumulator.resource(),
+ self._stats_accumulator.resource(), stamp_token,
+ next_stamp_token, self._multiclass_strategy, class_id,
+ self._feature_column_group_id, self._l1_regularization,
+ self._l2_regularization, self._tree_complexity_regularization,
+ self._min_node_weight))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional):
+ """Function that builds splits for a dense feature column."""
+ # Get the bucket boundaries
+ are_splits_ready, buckets = (
+ gen_quantile_ops.quantile_accumulator_get_buckets(
+ quantile_accumulator_handles=[quantile_accumulator_handle],
+ stamp_token=stamp_token))
+ # quantile_accumulator_get_buckets returns a list of results per handle that
+ # we pass to it. In this case we're getting results just for one resource.
+ are_splits_ready = are_splits_ready[0]
+ buckets = buckets[0]
+
+ # After we receive the boundaries from previous iteration we can flush
+ # the quantile accumulator.
+ with ops.control_dependencies([buckets]):
+ flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+ quantile_accumulator_handle=quantile_accumulator_handle,
+ stamp_token=stamp_token,
+ next_stamp_token=next_stamp_token)
+
+ if is_multi_dimentional:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+ else:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+
+ # Put quantile and stats accumulator flushing in the dependency path.
+ with ops.control_dependencies([flush_quantiles, partition_ids]):
+ are_splits_ready = array_ops.identity(are_splits_ready)
+ partition_ids, gains, split_infos = (
+ split_handler_ops.build_dense_inequality_splits(
+ num_minibatches=num_minibatches,
+ bucket_boundaries=buckets,
+ partition_ids=partition_ids,
+ bucket_ids=bucket_ids,
+ gradients=gradients,
+ hessians=hessians,
+ class_id=class_id,
+ feature_column_group_id=feature_column_id,
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
+ min_node_weight=min_node_weight,
+ multiclass_strategy=multiclass_strategy))
+ return are_splits_ready, partition_ids, gains, split_infos
class SparseSplitHandler(InequalitySplitHandler):
@@ -327,9 +363,6 @@ class SparseSplitHandler(InequalitySplitHandler):
multiclass_strategy=multiclass_strategy,
init_stamp_token=init_stamp_token,
name=name)
- # Register sparse_make_stats_update function as an Op to the graph.
- g = ops.get_default_graph()
- sparse_make_stats_update.add_to_graph(g)
self._sparse_float_column = sparse_float_column
def scheduled_reads(self):
@@ -361,8 +394,8 @@ class SparseSplitHandler(InequalitySplitHandler):
are_buckets_ready, buckets = scheduled_reads[0]
with ops.name_scope(self._name, "SparseSplitHandler"):
(quantile_indices, quantile_values, quantile_shapes, quantile_weights,
- example_partition_ids,
- feature_ids, gradients, hessians) = sparse_make_stats_update(
+ example_partition_ids, feature_ids, gradients,
+ hessians) = sparse_make_stats_update(
is_active, are_buckets_ready, self._sparse_float_column.indices,
self._sparse_float_column.values,
self._sparse_float_column.dense_shape, buckets,
@@ -379,42 +412,115 @@ class SparseSplitHandler(InequalitySplitHandler):
def make_splits(self, stamp_token, next_stamp_token, class_id):
"""Create the best split using the accumulated stats and flush the state."""
- # Get the bucket boundaries
- are_splits_ready, buckets = (
- self._quantile_accumulator.get_buckets(stamp_token))
-
- # After we receive the boundaries from previous iteration we can flush
- # the quantile accumulator.
- with ops.control_dependencies([buckets]):
- flush_quantiles = self._quantile_accumulator.flush(
- stamp_token=stamp_token, next_stamp_token=next_stamp_token)
-
- with ops.device(None):
- with ops.device(self._stats_accumulator.resource().device):
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- self._stats_accumulator.flush(stamp_token, next_stamp_token))
-
- # Put quantile and stats accumulator flushing in the dependency path.
- are_splits_ready = control_flow_ops.with_dependencies(
- [flush_quantiles, partition_ids], are_splits_ready)
- partition_ids, gains, split_infos = (
- split_handler_ops.build_sparse_inequality_splits(
- num_minibatches=num_minibatches,
- bucket_boundaries=buckets,
- partition_ids=partition_ids,
- bucket_ids=bucket_ids,
- gradients=gradients,
- hessians=hessians,
- class_id=class_id,
- feature_column_group_id=self._feature_column_group_id,
- l1_regularization=self._l1_regularization,
- l2_regularization=self._l2_regularization,
- tree_complexity_regularization=self.
- _tree_complexity_regularization,
- min_node_weight=self._min_node_weight,
- bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=self._multiclass_strategy))
- return (are_splits_ready, partition_ids, gains, split_infos)
+ if (self._gradient_shape == tensor_shape.scalar() and
+ self._hessian_shape == tensor_shape.scalar()):
+ handler = make_sparse_split_scalar
+ else:
+ handler = make_sparse_split_tensor
+
+ are_splits_ready, partition_ids, gains, split_infos = (
+ handler(self._quantile_accumulator.resource(),
+ self._stats_accumulator.resource(), stamp_token,
+ next_stamp_token, self._multiclass_strategy, class_id,
+ self._feature_column_group_id, self._l1_regularization,
+ self._l2_regularization, self._tree_complexity_regularization,
+ self._min_node_weight))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional):
+ """Function that builds splits for a sparse feature column."""
+ # Get the bucket boundaries
+ are_splits_ready, buckets = (
+ gen_quantile_ops.quantile_accumulator_get_buckets(
+ quantile_accumulator_handles=[quantile_accumulator_handle],
+ stamp_token=stamp_token))
+ # quantile_accumulator_get_buckets returns a list of results per handle that
+ # we pass to it. In this case we're getting results just for one resource.
+ are_splits_ready = are_splits_ready[0]
+ buckets = buckets[0]
+
+ # After we receive the boundaries from previous iteration we can flush
+ # the quantile accumulator.
+ with ops.control_dependencies([buckets]):
+ flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+ quantile_accumulator_handle=quantile_accumulator_handle,
+ stamp_token=stamp_token,
+ next_stamp_token=next_stamp_token)
+
+ if is_multi_dimentional:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+ else:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+
+ # Put quantile and stats accumulator flushing in the dependency path.
+ with ops.control_dependencies([flush_quantiles, partition_ids]):
+ are_splits_ready = array_ops.identity(are_splits_ready)
+ partition_ids, gains, split_infos = (
+ split_handler_ops.build_sparse_inequality_splits(
+ num_minibatches=num_minibatches,
+ bucket_boundaries=buckets,
+ partition_ids=partition_ids,
+ bucket_ids=bucket_ids,
+ gradients=gradients,
+ hessians=hessians,
+ class_id=class_id,
+ feature_column_group_id=feature_column_id,
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
+ min_node_weight=min_node_weight,
+ bias_feature_id=_BIAS_FEATURE_ID,
+ multiclass_strategy=multiclass_strategy))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _specialize_make_split(func, is_multi_dimentional):
+ """Builds a specialized version of the function."""
+
+ @function.Defun(
+ dtypes.resource,
+ dtypes.resource,
+ dtypes.int64,
+ dtypes.int64,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ noinline=True)
+ def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight):
+ """Function that builds splits for a sparse feature column."""
+ return func(
+ quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional)
+
+ return f
+
+make_dense_split_scalar = _specialize_make_split(_make_dense_split,
+ is_multi_dimentional=False)
+make_dense_split_tensor = _specialize_make_split(_make_dense_split,
+ is_multi_dimentional=True)
+
+make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
+ is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
+ is_multi_dimentional=True)
@function.Defun(
@@ -540,8 +646,9 @@ def sparse_make_stats_update(
empty_float = constant_op.constant([], dtype=dtypes.float32)
handler_not_active = (constant_op.constant(
- [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant(
- [0, 1], dtype=dtypes.int64), empty_float)
+ [], dtype=dtypes.int64, shape=[0, 2]), empty_float,
+ constant_op.constant([0, 1], dtype=dtypes.int64),
+ empty_float)
handler_active = (sparse_column_indices, sparse_column_values,
sparse_column_shape, weights)
quantile_indices, quantile_values, quantile_shape, quantile_weights = (
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 54d03018d9..2f2c230211 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import split_info_pb2
@@ -65,9 +67,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessian_shape = tensor_shape.scalar()
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
- l2_regularization=1,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=10,
feature_column_group_id=0,
@@ -92,7 +94,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
@@ -105,7 +109,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -199,10 +203,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessian_shape = tensor_shape.TensorShape([2, 2])
split_handler = ordinal_split_handler.DenseSplitHandler(
- l1_regularization=0,
- l2_regularization=1,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=3,
feature_column_group_id=0,
@@ -227,7 +231,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
@@ -240,7 +246,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -285,10 +291,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessian_shape = tensor_shape.TensorShape([2])
split_handler = ordinal_split_handler.DenseSplitHandler(
- l1_regularization=0,
- l2_regularization=1,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=3,
feature_column_group_id=0,
@@ -313,7 +319,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
@@ -326,7 +333,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -369,9 +376,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
- l2_regularization=1,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=10,
feature_column_group_id=0,
@@ -396,7 +403,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, False]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
@@ -409,7 +417,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([False, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -443,9 +451,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
- l2_regularization=1,
+ l2_regularization=1.,
tree_complexity_regularization=0.5,
- min_node_weight=0,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=10,
feature_column_group_id=0,
@@ -470,7 +478,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
@@ -483,7 +492,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -576,7 +585,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
- l2_regularization=1,
+ l2_regularization=1.,
tree_complexity_regularization=0.5,
min_node_weight=1.5,
epsilon=0.001,
@@ -603,7 +612,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
@@ -616,7 +626,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -685,10 +695,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
@@ -713,8 +723,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
-
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
@@ -727,7 +737,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -811,10 +821,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
@@ -839,7 +849,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
@@ -853,7 +864,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -905,10 +916,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
@@ -933,7 +944,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
@@ -947,7 +959,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -996,10 +1008,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
@@ -1024,7 +1036,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, False]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
@@ -1038,7 +1051,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([False, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -1065,10 +1078,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
@@ -1096,7 +1109,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
@@ -1110,7 +1124,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -1138,10 +1152,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
@@ -1166,7 +1180,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
@@ -1180,7 +1195,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
index 5d0ebbf73c..ca5c7f3d8c 100644
--- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -23,12 +23,6 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
REGISTER_OP("BuildDenseInequalitySplits")
- .Attr("feature_column_group_id: int")
- .Attr("l1_regularization: float")
- .Attr("l2_regularization: float")
- .Attr("tree_complexity_regularization: float")
- .Attr("min_node_weight: float")
- .Attr("multiclass_strategy: int")
.Input("num_minibatches: int64")
.Input("partition_ids: int32")
.Input("bucket_ids: int64")
@@ -36,6 +30,12 @@ REGISTER_OP("BuildDenseInequalitySplits")
.Input("hessians: float32")
.Input("bucket_boundaries: float32")
.Input("class_id: int32")
+ .Input("feature_column_group_id: int32")
+ .Input("l1_regularization: float")
+ .Input("l2_regularization: float")
+ .Input("tree_complexity_regularization: float")
+ .Input("min_node_weight: float")
+ .Input("multiclass_strategy: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -73,6 +73,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions.
gradients: A rank 1 tensor of gradients.
hessians: A rank 1 tensor of hessians.
bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are spiltting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+ regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+ If a split results in a leaf node with a smaller value, the split will not
+ be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+ See LearnerConfig.MultiClassStrategy for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
@@ -81,13 +92,6 @@ split_infos: A rank 1 tensor of serialized protos which contains the
)doc");
REGISTER_OP("BuildSparseInequalitySplits")
- .Attr("feature_column_group_id: int")
- .Attr("bias_feature_id: int")
- .Attr("l1_regularization: float")
- .Attr("l2_regularization: float")
- .Attr("tree_complexity_regularization: float")
- .Attr("min_node_weight: float")
- .Attr("multiclass_strategy: int")
.Input("num_minibatches: int64")
.Input("partition_ids: int32")
.Input("bucket_ids: int64")
@@ -95,6 +99,13 @@ REGISTER_OP("BuildSparseInequalitySplits")
.Input("hessians: float32")
.Input("bucket_boundaries: float32")
.Input("class_id: int32")
+ .Input("feature_column_group_id: int32")
+ .Input("bias_feature_id: int64")
+ .Input("l1_regularization: float")
+ .Input("l2_regularization: float")
+ .Input("tree_complexity_regularization: float")
+ .Input("min_node_weight: float")
+ .Input("multiclass_strategy: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -133,6 +144,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions.
gradients: A rank 1 tensor of gradients.
hessians: A rank 1 tensor of hessians.
bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are spiltting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+ regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+ If a split results in a leaf node with a smaller value, the split will not
+ be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+ See LearnerConfig.MultiClassStrategy for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
@@ -141,19 +163,19 @@ split_infos: A rank 1 tensor of serialized protos which contains the
)doc");
REGISTER_OP("BuildCategoricalEqualitySplits")
- .Attr("feature_column_group_id: int")
- .Attr("bias_feature_id: int")
- .Attr("l1_regularization: float")
- .Attr("l2_regularization: float")
- .Attr("tree_complexity_regularization: float")
- .Attr("min_node_weight: float")
- .Attr("multiclass_strategy: int")
.Input("num_minibatches: int64")
.Input("partition_ids: int32")
.Input("feature_ids: int64")
.Input("gradients: float32")
.Input("hessians: float32")
.Input("class_id: int32")
+ .Input("feature_column_group_id: int32")
+ .Input("bias_feature_id: int64")
+ .Input("l1_regularization: float")
+ .Input("l2_regularization: float")
+ .Input("tree_complexity_regularization: float")
+ .Input("min_node_weight: float")
+ .Input("multiclass_strategy: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -188,6 +210,17 @@ partition_ids: A rank 1 tensor of partition IDs.
feature_ids: A rank 2 tensor of feature IDs and dimensions.
gradients: A rank 1 tensor of gradients.
hessians: A rank 1 tensor of hessians.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are spiltting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+ regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+ If a split results in a leaf node with a smaller value, the split will not
+ be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+ See LearnerConfig.MultiClassStrategy for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
@@ -196,4 +229,3 @@ split_infos: A rank 1 tensor of serialized protos which contains the
)doc");
} // namespace tensorflow
- // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
index 7a5f329b7a..843420968a 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import abc
import collections
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -60,6 +62,7 @@ def _move_tensors(tensors, device):
"""Moves a list of tensors to a device by concatenating/splitting them."""
# Reset the device setting to avoid weird interactions with device merging
# logic.
+ zero = constant_op.constant(0, dtype=dtypes.int32)
with ops.device(None):
if all(tensor.shape == tensor_shape.scalar() for tensor in tensors):
with ops.device(tensors[0].device):
@@ -68,12 +71,11 @@ def _move_tensors(tensors, device):
return array_ops.unstack(values)
else:
with ops.device(tensors[0].device):
- sizes = array_ops.stack(
- [array_ops.shape(tensor)[0] for tensor in tensors])
- values = array_ops.concat(tensors, axis=0)
+ sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0]
+ values = array_ops.concat(tensors, axis=zero)
with ops.device(device):
sizes = array_ops.unstack(sizes)
- return list(array_ops.split(values, sizes, axis=0))
+ return list(array_ops.split(values, sizes, axis=zero))
def _scheduled_stamp_resource_op_runner(batch, stamp):
diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
index 50cc00afdc..19b6b3296d 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
@@ -201,3 +201,6 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
stamp_token=stamp_token,
next_stamp_token=next_stamp_token)
return result
+
+ def resource(self):
+ return self._quantile_accumulator_handle
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index e53d86ec61..5dd2e0c7f2 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -180,8 +180,7 @@ def extract_features(features, feature_columns, use_core_columns):
elif isinstance(fc, feature_column_lib._EmbeddingColumn):
# pylint: enable=protected-access
transformed_features[fc.name] = fc_core.input_layer(
- features, [fc],
- weight_collections=[scope])
+ features, [fc], weight_collections=[scope])
else:
result = feature_column_ops.transform_features(features, [fc])
if len(result) > 1:
@@ -334,10 +333,12 @@ class GradientBoostedDecisionTreeModel(object):
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
self._attempted_trees = variables.Variable(
- initial_value=array_ops.zeros([], dtypes.int64), trainable=False,
+ initial_value=array_ops.zeros([], dtypes.int64),
+ trainable=False,
name="attempted_trees")
self._finalized_trees = variables.Variable(
- initial_value=array_ops.zeros([], dtypes.int64), trainable=False,
+ initial_value=array_ops.zeros([], dtypes.int64),
+ trainable=False,
name="finalized_trees")
if not features:
raise ValueError("Features dictionary must be specified.")
@@ -354,9 +355,10 @@ class GradientBoostedDecisionTreeModel(object):
self._sparse_int_indices = sparse_int_indices
self._sparse_int_values = sparse_int_values
self._sparse_int_shapes = sparse_int_shapes
- self._reduce_dim = (self._learner_config.multi_class_strategy ==
- learner_pb2.LearnerConfig.TREE_PER_CLASS and
- learner_config.num_classes == 2)
+ self._reduce_dim = (
+ self._learner_config.multi_class_strategy ==
+ learner_pb2.LearnerConfig.TREE_PER_CLASS and
+ learner_config.num_classes == 2)
def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
"""Runs prediction and returns a dictionary of the prediction results.
@@ -374,8 +376,8 @@ class GradientBoostedDecisionTreeModel(object):
ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle,
ensemble_stamp)
num_handlers = (
- len(self._dense_floats) + len(self._sparse_float_shapes) +
- len(self._sparse_int_shapes))
+ len(self._dense_floats) + len(self._sparse_float_shapes) + len(
+ self._sparse_int_shapes))
# Used during feature selection.
used_handlers = model_ops.tree_ensemble_used_handlers(
ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers)
@@ -432,8 +434,9 @@ class GradientBoostedDecisionTreeModel(object):
# Use the current ensemble to predict on the current batch of input.
# For faster prediction we check if the inputs are on the same device
# as the model. If not, we create a copy of the model on the worker.
- input_deps = (self._dense_floats + self._sparse_float_indices +
- self._sparse_int_indices)
+ input_deps = (
+ self._dense_floats + self._sparse_float_indices +
+ self._sparse_int_indices)
if not input_deps:
raise ValueError("No input tensors for prediction.")
@@ -457,8 +460,8 @@ class GradientBoostedDecisionTreeModel(object):
# Determine whether the local ensemble is stale and update it if needed.
def _refresh_local_ensemble_fn():
- # Serialize the model from parameter server after reading all inputs.
- with ops.control_dependencies(input_deps):
+ # Serialize the model from parameter server after reading the inputs.
+ with ops.control_dependencies([input_deps[0]]):
(ensemble_stamp, serialized_model) = (
model_ops.tree_ensemble_serialize(self._ensemble_handle))
@@ -500,8 +503,9 @@ class GradientBoostedDecisionTreeModel(object):
ValueError: if inputs are not valid.
"""
# Get the worker device from input dependencies.
- input_deps = (self._dense_floats + self._sparse_float_indices +
- self._sparse_int_indices)
+ input_deps = (
+ self._dense_floats + self._sparse_float_indices +
+ self._sparse_int_indices)
worker_device = input_deps[0].device
# Get tensors relevant for training and form the loss.
@@ -517,7 +521,7 @@ class GradientBoostedDecisionTreeModel(object):
aggregation_method=None)[0]
strategy = self._learner_config.multi_class_strategy
- class_id = -1
+ class_id = constant_op.constant(-1, dtype=dtypes.int32)
# Handle different multiclass strategies.
if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS:
# We build one vs rest trees.
@@ -571,31 +575,39 @@ class GradientBoostedDecisionTreeModel(object):
# Get the weights for each example for quantiles calculation,
weights = self._get_weights(hessian_shape, squeezed_hessians)
- regularization_config = self._learner_config.regularization
- min_node_weight = self._learner_config.constraints.min_node_weight
# Create all handlers ensuring resources are evenly allocated across PS.
fc_name_idx = 0
handlers = []
init_stamp_token = constant_op.constant(0, dtype=dtypes.int64)
+ l1_regularization = constant_op.constant(
+ self._learner_config.regularization.l1, dtypes.float32)
+ l2_regularization = constant_op.constant(
+ self._learner_config.regularization.l2, dtypes.float32)
+ tree_complexity_regularization = constant_op.constant(
+ self._learner_config.regularization.tree_complexity, dtypes.float32)
+ min_node_weight = constant_op.constant(
+ self._learner_config.constraints.min_node_weight, dtypes.float32)
+ epsilon = 0.01
+ num_quantiles = 100
+ strategy_tensor = constant_op.constant(strategy)
with ops.device(self._get_replica_device_setter(worker_device)):
# Create handlers for dense float columns
for dense_float_column_idx in range(len(self._dense_floats)):
fc_name = self._fc_names[fc_name_idx]
handlers.append(
ordinal_split_handler.DenseSplitHandler(
- l1_regularization=regularization_config.l1,
- l2_regularization=regularization_config.l2,
- tree_complexity_regularization=(
- regularization_config.tree_complexity),
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
feature_column_group_id=dense_float_column_idx,
- epsilon=0.01,
- num_quantiles=100,
+ epsilon=epsilon,
+ num_quantiles=num_quantiles,
dense_float_column=self._dense_floats[dense_float_column_idx],
name=fc_name,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
- multiclass_strategy=strategy,
+ multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token))
fc_name_idx += 1
@@ -604,14 +616,13 @@ class GradientBoostedDecisionTreeModel(object):
fc_name = self._fc_names[fc_name_idx]
handlers.append(
ordinal_split_handler.SparseSplitHandler(
- l1_regularization=regularization_config.l1,
- l2_regularization=regularization_config.l2,
- tree_complexity_regularization=(
- regularization_config.tree_complexity),
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
feature_column_group_id=sparse_float_column_idx,
- epsilon=0.01,
- num_quantiles=100,
+ epsilon=epsilon,
+ num_quantiles=num_quantiles,
sparse_float_column=sparse_tensor.SparseTensor(
self._sparse_float_indices[sparse_float_column_idx],
self._sparse_float_values[sparse_float_column_idx],
@@ -619,7 +630,7 @@ class GradientBoostedDecisionTreeModel(object):
name=fc_name,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
- multiclass_strategy=strategy,
+ multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token))
fc_name_idx += 1
@@ -628,10 +639,9 @@ class GradientBoostedDecisionTreeModel(object):
fc_name = self._fc_names[fc_name_idx]
handlers.append(
categorical_split_handler.EqualitySplitHandler(
- l1_regularization=regularization_config.l1,
- l2_regularization=regularization_config.l2,
- tree_complexity_regularization=(
- regularization_config.tree_complexity),
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
feature_column_group_id=sparse_int_column_idx,
sparse_int_column=sparse_tensor.SparseTensor(
@@ -641,7 +651,7 @@ class GradientBoostedDecisionTreeModel(object):
name=fc_name,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
- multiclass_strategy=strategy,
+ multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token))
fc_name_idx += 1
@@ -694,11 +704,11 @@ class GradientBoostedDecisionTreeModel(object):
name="continue_centering",
trainable=False)
stats_update_ops.append(
- control_flow_ops.cond(continue_centering,
- self._make_update_bias_stats_fn(
- ensemble_stamp, predictions, gradients,
- bias_stats_accumulator),
- control_flow_ops.no_op))
+ control_flow_ops.cond(
+ continue_centering,
+ self._make_update_bias_stats_fn(ensemble_stamp, predictions,
+ gradients, bias_stats_accumulator),
+ control_flow_ops.no_op))
# Update handler stats.
handler_reads = collections.OrderedDict()
@@ -720,8 +730,8 @@ class GradientBoostedDecisionTreeModel(object):
shape=[len(handlers)], seed=[seed + 1, 1])
active_handlers = array_ops.stack(
[active_handlers_current_layer, active_handlers_next_layer], axis=1)
- active_handlers = (active_handlers <
- self._learner_config.feature_fraction_per_level)
+ active_handlers = (
+ active_handlers < self._learner_config.feature_fraction_per_level)
elif subsampling_type == "feature_fraction_per_tree":
seed = predictions_dict[NUM_TREES_ATTEMPTED]
active_handlers_current_layer = stateless.stateless_random_uniform(
@@ -729,9 +739,12 @@ class GradientBoostedDecisionTreeModel(object):
active_handlers_current_layer = (
active_handlers_current_layer <
self._learner_config.feature_fraction_per_tree)
- active_handlers = array_ops.stack([
- active_handlers_current_layer,
- array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1)
+ active_handlers = array_ops.stack(
+ [
+ active_handlers_current_layer,
+ array_ops.ones([len(handlers)], dtype=dtypes.bool)
+ ],
+ axis=1)
else:
active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool)
@@ -760,6 +773,7 @@ class GradientBoostedDecisionTreeModel(object):
empty_hessians = constant_op.constant(
[], dtype=dtypes.float32, shape=empty_hess_shape)
+ active_handlers = array_ops.unstack(active_handlers, axis=0)
for handler_idx in range(len(handlers)):
handler = handlers[handler_idx]
is_active = active_handlers[handler_idx]
@@ -901,7 +915,6 @@ class GradientBoostedDecisionTreeModel(object):
"DecisionTreeEnsembleResourceHandleOp",
"StatsAccumulatorScalarResourceHandleOp",
"StatsAccumulatorTensorResourceHandleOp",
- "QuantileStreamResourceHandleOp",
]
ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks)
return device_setter.replica_device_setter(
@@ -971,7 +984,7 @@ class GradientBoostedDecisionTreeModel(object):
# This is a workaround for the slowness of graph building in tf.cond.
# See (b/36554864).
split_sizes = array_ops.reshape(
- array_ops.shape_n(partition_ids_list), [-1])
+ array_ops.shape_n(partition_ids_list), [len(partition_ids_list)])
partition_ids = array_ops.concat(partition_ids_list, axis=0)
gains = array_ops.concat(gains_list, axis=0)
split_infos = array_ops.concat(split_info_list, axis=0)
@@ -1036,8 +1049,11 @@ class GradientBoostedDecisionTreeModel(object):
# Update ensemble.
update_ops = [are_all_splits_ready]
- update_model = control_flow_ops.cond(continue_centering, _center_bias_fn,
- _grow_ensemble_fn)
+ if self._center_bias:
+ update_model = control_flow_ops.cond(continue_centering,
+ _center_bias_fn, _grow_ensemble_fn)
+ else:
+ update_model = _grow_ensemble_fn()
update_ops.append(update_model)
# Update ensemble stats.
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index f9c22283b7..289fb195db 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -31,7 +31,6 @@ from tensorflow.python.feature_column import feature_column_lib as core_feature_
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn
-
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
@@ -97,8 +96,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
array_ops.zeros([2], dtypes.int64))
features["sparse_int"] = sparse_tensor.SparseTensor(
array_ops.zeros([2, 2], dtypes.int64),
- array_ops.zeros([2], dtypes.int64),
- array_ops.zeros([2], dtypes.int64))
+ array_ops.zeros([2], dtypes.int64), array_ops.zeros([2],
+ dtypes.int64))
(fc_names, dense_floats, sparse_float_indices, sparse_float_values,
sparse_float_shapes, sparse_int_indices, sparse_int_values,
sparse_int_shapes) = (
@@ -139,8 +138,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
array_ops.zeros([2], dtypes.int64))
features["sparse_categorical"] = sparse_tensor.SparseTensor(
array_ops.zeros([2, 2], dtypes.int64),
- array_ops.zeros(
- [2], dtypes.string), array_ops.zeros([2], dtypes.int64))
+ array_ops.zeros([2], dtypes.string), array_ops.zeros([2],
+ dtypes.int64))
feature_columns = set()
feature_columns.add(layers.real_valued_column("dense_float"))
feature_columns.add(
@@ -235,7 +234,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
- logits_dimension=1, features=features)
+ logits_dimension=1,
+ features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@@ -316,6 +316,113 @@ class GbdtTest(test_util.TensorFlowTestCase):
}"""
self.assertProtoEquals(expected_tree, output.trees[0])
+ def testTrainFnChiefSparseAndDense(self):
+ """Tests the train function with sparse and dense features."""
+ with self.test_session() as sess:
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
+ learner_config.num_classes = 2
+ learner_config.regularization.l1 = 0
+ learner_config.regularization.l2 = 0
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.constraints.min_node_weight = 0
+ features = {}
+ features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)
+ features["sparse_float"] = sparse_tensor.SparseTensor(
+ array_ops.zeros([2, 2], dtypes.int64),
+ array_ops.zeros([2], dtypes.float32),
+ array_ops.constant([4, 1], dtypes.int64))
+
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=True,
+ num_ps_replicas=0,
+ center_bias=False,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features)
+
+ predictions = array_ops.constant(
+ [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
+ partition_ids = array_ops.zeros([4], dtypes.int32)
+ ensemble_stamp = variables.Variable(
+ initial_value=0,
+ name="ensemble_stamp",
+ trainable=False,
+ dtype=dtypes.int64)
+
+ predictions_dict = {
+ "predictions": predictions,
+ "predictions_no_dropout": predictions,
+ "partition_ids": partition_ids,
+ "ensemble_stamp": ensemble_stamp,
+ "num_trees": 12,
+ }
+
+ labels = array_ops.ones([4, 1], dtypes.float32)
+ weights = array_ops.ones([4, 1], dtypes.float32)
+ # Create train op.
+ train_op = gbdt_model.train(
+ loss=math_ops.reduce_mean(
+ _squared_loss(labels, weights, predictions)),
+ predictions_dict=predictions_dict,
+ labels=labels)
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # On first run, expect no splits to be chosen because the quantile
+ # buckets will not be ready.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 0)
+ self.assertEquals(len(output.tree_weights), 0)
+ self.assertEquals(stamp_token.eval(), 1)
+
+ # Update the stamp to be able to run a second time.
+ sess.run([ensemble_stamp.assign_add(1)])
+
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [0.1])
+ self.assertEquals(stamp_token.eval(), 2)
+ expected_tree = """
+ nodes {
+ sparse_float_binary_split_default_right {
+ split{
+ left_id: 1
+ right_id: 2
+ }
+ }
+ node_metadata {
+ gain: 1.125
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -0.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+
def testTrainFnChiefScalingNumberOfExamples(self):
"""Tests the train function running on chief without bias centering."""
with self.test_session() as sess:
@@ -339,7 +446,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=num_examples_fn,
learner_config=learner_config,
- logits_dimension=1, features=features)
+ logits_dimension=1,
+ features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@@ -442,7 +550,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
- logits_dimension=1, features=features)
+ logits_dimension=1,
+ features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@@ -513,7 +622,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
- logits_dimension=1, features=features)
+ logits_dimension=1,
+ features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@@ -576,7 +686,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
- logits_dimension=1, features=features)
+ logits_dimension=1,
+ features=features)
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
@@ -622,7 +733,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
with self.test_session() as sess:
# Create ensemble with one bias node.
ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
- text_format.Merge("""
+ text_format.Merge(
+ """
trees {
nodes {
leaf {
@@ -659,14 +771,15 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
- logits_dimension=1, features=features)
+ logits_dimension=1,
+ features=features)
# Create predict op.
mode = model_fn.ModeKeys.EVAL
predictions_dict = sess.run(gbdt_model.predict(mode))
self.assertEquals(predictions_dict["ensemble_stamp"], 3)
- self.assertAllClose(predictions_dict["predictions"], [[0.25], [0.25],
- [0.25], [0.25]])
+ self.assertAllClose(predictions_dict["predictions"],
+ [[0.25], [0.25], [0.25], [0.25]])
self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0])
def testTrainFnMulticlassFullHessian(self):
@@ -698,7 +811,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
- logits_dimension=5, features=features)
+ logits_dimension=5,
+ features=features)
predictions = array_ops.constant(
[[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
@@ -801,7 +915,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
- logits_dimension=5, features=features)
+ logits_dimension=5,
+ features=features)
predictions = array_ops.constant(
[[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
@@ -893,8 +1008,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
learner_config.constraints.max_tree_depth = 1
learner_config.constraints.min_node_weight = 0
features = {
- "dense_float": array_ops.constant(
- [[1.0], [1.5], [2.0]], dtypes.float32),
+ "dense_float":
+ array_ops.constant([[1.0], [1.5], [2.0]], dtypes.float32),
}
gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
@@ -904,7 +1019,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
ensemble_handle=ensemble_handle,
examples_per_layer=1,
learner_config=learner_config,
- logits_dimension=5, features=features)
+ logits_dimension=5,
+ features=features)
batch_size = 3
predictions = array_ops.constant(
@@ -986,7 +1102,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
self.assertAllClose(
0.893284678459,
output.trees[0].nodes[2].leaf.sparse_vector.value[0],
- atol=1e-4, rtol=1e-4)
+ atol=1e-4,
+ rtol=1e-4)
def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self):
"""Tests the train function running on chief with feature selection."""
@@ -1230,9 +1347,9 @@ class GbdtTest(test_util.TensorFlowTestCase):
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree = tree_ensemble_config.trees.add()
- _set_float_split(tree.nodes.add()
- .sparse_float_binary_split_default_right.split, 2, 4.0,
- 1, 2)
+ _set_float_split(
+ tree.nodes.add().sparse_float_binary_split_default_right.split, 2,
+ 4.0, 1, 2)
_append_to_leaf(tree.nodes.add().leaf, 0, 0.5)
_append_to_leaf(tree.nodes.add().leaf, 1, 1.2)
tree_ensemble_config.tree_weights.append(1.0)
@@ -1241,7 +1358,8 @@ class GbdtTest(test_util.TensorFlowTestCase):
metadata.num_layers_grown = 1
tree_ensemble_config = tree_ensemble_config.SerializeToString()
ensemble_handle = model_ops.tree_ensemble_variable(
- stamp_token=0, tree_ensemble_config=tree_ensemble_config,
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config,
name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
@@ -1333,5 +1451,6 @@ class GbdtTest(test_util.TensorFlowTestCase):
self.assertEquals(output.growing_metadata.num_layers_attempted, 2)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index af8df72618..8ae493ba99 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -18,11 +18,15 @@ Visualization and inspection:
@@dot_graph_from_checkpoint
@@object_metadata
-Creating and managing dependencies:
+Managing dependencies:
@@Checkpointable
@@CheckpointableObjectGraph
@@NoDependency
@@split_dependency
+
+Checkpointable data structures:
+@@List
+@@Mapping
@@UniqueNameTracker
"""
@@ -36,8 +40,11 @@ from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkp
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
from tensorflow.python.training.checkpointable.base import Checkpointable
from tensorflow.python.training.checkpointable.base import NoDependency
+from tensorflow.python.training.checkpointable.data_structures import List
+from tensorflow.python.training.checkpointable.data_structures import Mapping
from tensorflow.python.training.checkpointable.util import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(module_name=__name__)
+
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
index 53f4e97f99..7b200a29bf 100644
--- a/tensorflow/contrib/checkpoint/python/BUILD
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -11,6 +11,7 @@ py_library(
":containers",
":split_dependency",
":visualize",
+ "//tensorflow/python/training/checkpointable:data_structures",
],
)
@@ -19,7 +20,10 @@ py_library(
srcs = ["containers.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
- deps = ["//tensorflow/python/training/checkpointable:base"],
+ deps = [
+ "//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:data_structures",
+ ],
)
py_test(
@@ -30,8 +34,8 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:training",
"//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
"@six_archive//:six",
],
)
@@ -44,6 +48,7 @@ py_library(
deps = [
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/training/checkpointable:base",
],
)
@@ -55,8 +60,9 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:training",
"//tensorflow/python/eager:test",
+ "//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
],
)
@@ -67,6 +73,8 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
],
)
@@ -75,10 +83,13 @@ py_test(
srcs = ["visualize_test.py"],
deps = [
":visualize",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/keras:layers",
+ "//tensorflow/python/training/checkpointable:util",
],
)
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
index 9807abae1f..4d3d531299 100644
--- a/tensorflow/contrib/checkpoint/python/containers.py
+++ b/tensorflow/contrib/checkpoint/python/containers.py
@@ -18,9 +18,10 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import data_structures
-class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
+class UniqueNameTracker(data_structures.CheckpointableDataStructure):
"""Adds dependencies on checkpointable objects with name hints.
Useful for creating dependencies with locally unique names.
@@ -41,6 +42,7 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
"""
def __init__(self):
+ super(UniqueNameTracker, self).__init__()
self._maybe_initialize_checkpointable()
self._name_counts = {}
@@ -74,4 +76,5 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
count += 1
candidate = _format_name(base_name, count)
self._name_counts[base_name] = count + 1
- return self._track_checkpointable(checkpointable, name=candidate)
+ self._track_value(checkpointable, name=candidate)
+ return checkpointable
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
index 851a800588..3717d7f583 100644
--- a/tensorflow/contrib/checkpoint/python/containers_test.py
+++ b/tensorflow/contrib/checkpoint/python/containers_test.py
@@ -22,6 +22,8 @@ import six
from tensorflow.contrib.checkpoint.python import containers
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import layers
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -95,5 +97,12 @@ class UniqueNameTrackerTests(test.TestCase):
dependency_names,
["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"])
+ @test_util.run_in_graph_and_eager_modes()
+ def testLayers(self):
+ tracker = containers.UniqueNameTracker()
+ tracker.track(layers.Dense(3), "dense")
+ tracker.layers[0](array_ops.zeros([1, 1]))
+ self.assertEqual(2, len(tracker.trainable_weights))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake
index a06bdf78fb..2e0a2fcef4 100644
--- a/tensorflow/contrib/cmake/tf_c.cmake
+++ b/tensorflow/contrib/cmake/tf_c.cmake
@@ -21,6 +21,7 @@ set(tf_c_srcs
"${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
+ "${tensorflow_source_dir}/tensorflow/c/eager/c_api_debug.cc"
"${tensorflow_source_dir}/tensorflow/c/eager/tape.h"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc"
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index b47c32f1c4..dac84ccb0d 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -213,10 +213,6 @@ else()
list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude})
endif()
-file(GLOB tf_core_platform_exclude_srcs
- "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc")
-list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_exclude_srcs})
-
list(APPEND tf_core_lib_srcs ${tf_core_platform_srcs})
if(UNIX)
@@ -286,8 +282,6 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c
file(GLOB_RECURSE tf_core_framework_srcs
"${tensorflow_source_dir}/tensorflow/core/framework/*.h"
"${tensorflow_source_dir}/tensorflow/core/framework/*.cc"
- "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.h"
- "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h"
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index a25aa85251..1af1ed08b5 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -30,6 +30,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@assert_element_shape
@@batch_and_drop_remainder
@@bucket_by_sequence_length
+@@choose_from_datasets
@@dense_to_sparse_batch
@@enumerate_dataset
@@group_by_window
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index f5082228e8..c483a43769 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -208,6 +208,23 @@ py_test(
],
)
+py_test(
+ name = "directed_interleave_dataset_test",
+ size = "medium",
+ srcs = ["directed_interleave_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
tf_py_test(
name = "get_single_element_test",
size = "small",
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 641a389c03..f9f11a1555 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -308,6 +308,23 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults=record_defaults,
)
+ def testMakeCsvDataset_fieldOrder(self):
+ data = [[
+ '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
+ '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
+ ]]
+ file_path = self.setup_files(data)
+
+ with ops.Graph().as_default() as g:
+ ds = readers.make_csv_dataset(
+ file_path, batch_size=1, shuffle=False, num_epochs=1)
+ next_batch = ds.make_one_shot_iterator().get_next()
+
+ with self.test_session(graph=g) as sess:
+ result = list(sess.run(next_batch).values())
+
+ self.assertEqual(result, sorted(result))
+
class CsvDatasetBenchmark(test.Benchmark):
"""Benchmarks for the various ways of creating a dataset from CSV files.
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
new file mode 100644
index 0000000000..34b6a080c0
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -0,0 +1,167 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import random_seed
+from tensorflow.python.platform import test
+
+
+class DirectedInterleaveDatasetTest(test.TestCase):
+
+ def testBasic(self):
+ selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
+ input_datasets = [
+ dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
+ ]
+ dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
+ input_datasets)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(100):
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def _normalize(self, vec):
+ return vec / vec.sum()
+
+ def _chi2(self, expected, actual):
+ actual = np.asarray(actual)
+ expected = np.asarray(expected)
+ diff = actual - expected
+ chi2 = np.sum(diff * diff / expected, axis=0)
+ return chi2
+
+ def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
+ # Create a dataset that samples each integer in `[0, num_datasets)`
+ # with probability given by `weights[i]`.
+ dataset = interleave_ops.sample_from_datasets([
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(num_datasets)
+ ], weights)
+ dataset = dataset.take(num_samples)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ freqs = np.zeros([num_datasets])
+ for _ in range(num_samples):
+ freqs[sess.run(next_element)] += 1
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ return freqs
+
+ def testSampleFromDatasets(self):
+ random_seed.set_random_seed(1619)
+ num_samples = 5000
+ rand_probs = self._normalize(np.random.random_sample((15,)))
+
+ # Use chi-squared test to assert that the observed distribution matches the
+ # expected distribution. Based on the implementation in
+ # "tensorflow/python/kernel_tests/multinomial_op_test.py".
+ for probs in [[.85, .05, .1], rand_probs]:
+ probs = np.asarray(probs)
+ classes = len(probs)
+ freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
+ self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
+
+ # Also check that `weights` as a dataset samples correctly.
+ probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
+ freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
+ self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
+
+ def testSelectFromDatasets(self):
+ words = [b"foo", b"bar", b"baz"]
+ datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
+ choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
+ choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
+ dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ for i in choice_array:
+ self.assertEqual(words[i], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testErrors(self):
+ with self.assertRaisesRegexp(ValueError,
+ r"vector of length `len\(datasets\)`"):
+ interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.range(20)],
+ weights=[0.25, 0.25, 0.25, 0.25])
+
+ with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
+ interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.range(20)],
+ weights=[1, 1])
+
+ with self.assertRaisesRegexp(TypeError, "must have the same type"):
+ interleave_ops.sample_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(0.0)
+ ])
+
+ with self.assertRaisesRegexp(TypeError, "tf.int64"):
+ interleave_ops.choose_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(1)
+ ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))
+
+ with self.assertRaisesRegexp(TypeError, "scalar"):
+ interleave_ops.choose_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(1)
+ ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))
+
+
+class SampleFromDatasetsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, probs, num_samples):
+ dataset = interleave_ops.sample_from_datasets(
+ [
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(len(probs))
+ ],
+ probs,
+ seed=1813)
+ return dataset.take(num_samples)
+
+ def testSerializationCore(self):
+ self.run_core_tests(
+ lambda: self._build_dataset([0.5, 0.5], 100),
+ lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index 43aa4b1bd0..bee561e3e2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -907,114 +906,5 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(self.next_element)
-class DirectedInterleaveDatasetTest(test.TestCase):
-
- def testBasic(self):
- selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
- input_datasets = [
- dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
- ]
- dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
- input_datasets)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.test_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(100):
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def _normalize(self, vec):
- return vec / vec.sum()
-
- def _chi2(self, expected, actual):
- actual = np.asarray(actual)
- expected = np.asarray(expected)
- diff = actual - expected
- chi2 = np.sum(diff * diff / expected, axis=0)
- return chi2
-
- def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
- # Create a dataset that samples each integer in `[0, num_datasets)`
- # with probability given by `weights[i]`.
- dataset = interleave_ops.sample_from_datasets([
- dataset_ops.Dataset.from_tensors(i).repeat(None)
- for i in range(num_datasets)
- ], weights)
- dataset = dataset.take(num_samples)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.test_session() as sess:
- freqs = np.zeros([num_datasets])
- for _ in range(num_samples):
- freqs[sess.run(next_element)] += 1
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- return freqs
-
- def testSampleFromDatasets(self):
- random_seed.set_random_seed(1619)
- num_samples = 10000
- rand_probs = self._normalize(np.random.random_sample((15,)))
-
- # Use chi-squared test to assert that the observed distribution matches the
- # expected distribution. Based on the implementation in
- # "tensorflow/python/kernel_tests/multinomial_op_test.py".
- for probs in [[.85, .05, .1], rand_probs]:
- probs = np.asarray(probs)
- classes = len(probs)
- freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
- self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3)
-
- # Also check that `weights` as a dataset samples correctly.
- probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
- freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
- self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3)
-
- def testErrors(self):
- with self.assertRaisesRegexp(ValueError,
- r"vector of length `len\(datasets\)`"):
- interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.range(10),
- dataset_ops.Dataset.range(20)],
- weights=[0.25, 0.25, 0.25, 0.25])
-
- with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
- interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.range(10),
- dataset_ops.Dataset.range(20)],
- weights=[1, 1])
-
- with self.assertRaisesRegexp(TypeError, "must have the same type"):
- interleave_ops.sample_from_datasets([
- dataset_ops.Dataset.from_tensors(0),
- dataset_ops.Dataset.from_tensors(0.0)
- ])
-
-
-class SampleFromDatasetsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, probs, num_samples):
- dataset = interleave_ops.sample_from_datasets(
- [
- dataset_ops.Dataset.from_tensors(i).repeat(None)
- for i in range(len(probs))
- ],
- probs,
- seed=1813)
- return dataset.take(num_samples)
-
- def testSerializationCore(self):
- self.run_core_tests(
- lambda: self._build_dataset([0.5, 0.5], 100),
- lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 812a50ecbf..be66fbac50 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -27,6 +27,7 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
@@ -240,3 +241,47 @@ def sample_from_datasets(datasets, weights=None, seed=None):
(logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
return DirectedInterleaveDataset(selector_input, datasets)
+
+
+def choose_from_datasets(datasets, choice_dataset):
+ """Creates a dataset that deterministically chooses elements from `datasets`.
+
+ For example, given the following datasets:
+
+ ```python
+ datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
+ tf.data.Dataset.from_tensors("bar").repeat(),
+ tf.data.Dataset.from_tensors("baz").repeat()]
+
+ # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
+ choice_dataset = tf.data.Dataset.range(3).repeat(3)
+
+ result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)
+ ```
+
+ The elements of `result` will be:
+
+ ```
+ "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
+ ```
+
+ Args:
+ datasets: A list of @{tf.data.Dataset} objects with compatible structure.
+ choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between
+ `0` and `len(datasets) - 1`.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` according to the values
+ of `choice_dataset`.
+
+ Raises:
+ TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
+ type.
+ """
+ if not (choice_dataset.output_types == dtypes.int64
+ and choice_dataset.output_shapes.is_compatible_with(
+ tensor_shape.scalar())
+ and choice_dataset.output_classes == ops.Tensor):
+ raise TypeError("`choice_dataset` must be a dataset of scalar "
+ "`tf.int64` tensors.")
+ return DirectedInterleaveDataset(choice_dataset, datasets)
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 75c31a944a..f938153f5f 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import csv
import numpy as np
@@ -467,11 +468,11 @@ def make_csv_dataset(
Args:
*columns: list of `Tensor`s corresponding to one csv record.
Returns:
- A dictionary of feature names to values for that particular record. If
+ An OrderedDict of feature names to values for that particular record. If
label_name is provided, extracts the label feature to be returned as the
second element of the tuple.
"""
- features = dict(zip(column_names, columns))
+ features = collections.OrderedDict(zip(column_names, columns))
if label_name is not None:
label = features.pop(label_name)
return features, label
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 64a77bbed1..3118deaa47 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -445,6 +445,7 @@ py_library(
srcs = ["cross_tower_utils.py"],
srcs_version = "PY2AND3",
deps = [
+ ":values",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
@@ -452,6 +453,24 @@ py_library(
],
)
+cuda_py_test(
+ name = "cross_tower_utils_test",
+ srcs = ["cross_tower_utils_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":cross_tower_utils",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ ],
+ tags = [
+ "no_pip",
+ ],
+)
+
py_library(
name = "cross_tower_ops",
srcs = ["cross_tower_ops.py"],
@@ -547,3 +566,21 @@ cuda_py_test(
"no_pip",
],
)
+
+cuda_py_test(
+ name = "keras_test",
+ srcs = ["keras_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/distribute/python:mirrored_strategy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator:keras",
+ "//tensorflow/python/estimator:run_config",
+ "//tensorflow/python/keras",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "notsan",
+ ],
+)
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 15935817b0..e400fa5be2 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -41,7 +41,10 @@ from __future__ import print_function
from collections import OrderedDict
import sys
+import types
+import unittest
from absl.testing import parameterized
+import six
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import one_device_strategy
@@ -67,8 +70,8 @@ def generate(combinations):
combinations: a list of dictionaries created using combine() and times().
Restrictions:
- -- there should always be a "mode" argument. Accepted values are "eager"
- and "graph".
+ -- the "mode" argument can be either "eager" or "graph". It's "graph" by
+ default.
-- arguments of the test method must match by name to get the corresponding
value of the combination. Tests must accept all arguments except the
"mode", "required_tpu" and "required_gpus".
@@ -83,14 +86,15 @@ def generate(combinations):
test will be skipped if the specified number of GPUs aren't available.
Returns:
- a decorator that will cause the test method to be run under the specified
- conditions.
+ a decorator that will cause the test method or the test class to be run
+ under the specified conditions.
Raises:
- ValueError - if "mode" argument wasn't either "eager" or "graph".
+ ValueError - if "mode" argument wasn't either "eager" or "graph" or if other
+ arguments were not accepted by the test method.
"""
- def decorator(test_function):
+ def decorator(test_method_or_class):
"""The decorator to be returned."""
# Generate good test names that can be used with --test_filter.
@@ -110,70 +114,91 @@ def generate(combinations):
list(combination.items()) + [("testcase_name",
"_test{}".format(name))]))
- @parameterized.named_parameters(*named_combinations)
- def decorated(self, **kwargs):
- """A wrapped test method that sets up `test_function`."""
- assert "mode" in kwargs
- mode = kwargs["mode"]
-
- distribution = kwargs.pop("distribution", None)
- required_tpu = kwargs.pop("required_tpu", False)
- required_gpus = kwargs.pop("required_gpus", None)
-
- if distribution:
- assert required_gpus is None, (
- "Do not use `required_gpus` and `distribution` together.")
- assert required_tpu is False, (
- "Do not use `required_tpu` and `distribution` together.")
- kwargs["distribution"] = distribution.strategy
- required_gpus = distribution.required_gpus
- required_tpu = distribution.required_tpu
-
- if required_tpu and not TPU_TEST:
- self.skipTest("Test requires a TPU, but it's not available.")
- if not required_tpu and TPU_TEST:
- self.skipTest("Test that doesn't require a TPU.")
-
- if not required_gpus:
- if GPU_TEST:
- self.skipTest("Test that doesn't require GPUs.")
- elif context.num_gpus() < required_gpus:
- self.skipTest(
- "{} GPUs are not available for this test. {} GPUs are available".
- format(required_gpus, context.num_gpus()))
-
- # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu`
- # that the user might have specified. `kwargs` still has `mode`, which
- # the test is allowed to accept or ignore.
- requested_arguments = tf_inspect.getfullargspec(test_function).args
- missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
- set(requested_arguments + ["mode"]))
- if missing_arguments:
- raise ValueError("The test is missing arguments {} .".format(
- missing_arguments))
-
- kwargs_to_pass = {}
- for arg in requested_arguments:
- if arg == "self":
- kwargs_to_pass[arg] = self
- else:
- kwargs_to_pass[arg] = kwargs[arg]
-
- if mode == "eager":
- with context.eager_mode(), ops.Graph().as_default():
- test_function(**kwargs_to_pass)
- elif mode == "graph":
- with context.graph_mode(), ops.Graph().as_default():
- test_function(**kwargs_to_pass)
- else:
- raise ValueError(
- "'mode' has to be either 'eager' or 'graph' and not {}".format(
- mode))
+ if isinstance(test_method_or_class, type):
+ class_object = test_method_or_class
+ class_object._test_method_ids = test_method_ids = {}
+ for name, test_method in six.iteritems(class_object.__dict__.copy()):
+ if (name.startswith(unittest.TestLoader.testMethodPrefix) and
+ isinstance(test_method, types.FunctionType)):
+ delattr(class_object, name)
+ methods = {}
+ parameterized._update_class_dict_for_param_test_case(
+ class_object.__name__, methods, test_method_ids, name,
+ parameterized._ParameterizedTestIter(
+ _augment_with_special_arguments(test_method),
+ named_combinations, parameterized._NAMED, name))
+ for method_name, method in six.iteritems(methods):
+ setattr(class_object, method_name, method)
+
+ return class_object
+ else:
+ test_method = _augment_with_special_arguments(test_method_or_class)
+ return parameterized.named_parameters(*named_combinations)(test_method)
- return decorated
return decorator
+def _augment_with_special_arguments(test_method):
+ def decorated(self, **kwargs):
+ """A wrapped test method that treats some arguments in a special way."""
+ mode = kwargs.pop("mode", "graph")
+
+ distribution = kwargs.pop("distribution", None)
+ required_tpu = kwargs.pop("required_tpu", False)
+ required_gpus = kwargs.pop("required_gpus", None)
+
+ if distribution:
+ assert required_gpus is None, (
+ "Do not use `required_gpus` and `distribution` together.")
+ assert required_tpu is False, (
+ "Do not use `required_tpu` and `distribution` together.")
+ kwargs["distribution"] = distribution.strategy
+ required_gpus = distribution.required_gpus
+ required_tpu = distribution.required_tpu
+
+ if required_tpu and not TPU_TEST:
+ self.skipTest("Test requires a TPU, but it's not available.")
+ if not required_tpu and TPU_TEST:
+ self.skipTest("Test that doesn't require a TPU.")
+
+ if not required_gpus:
+ if GPU_TEST:
+ self.skipTest("Test that doesn't require GPUs.")
+ elif context.num_gpus() < required_gpus:
+ self.skipTest(
+ "{} GPUs are not available for this test. {} GPUs are available".
+ format(required_gpus, context.num_gpus()))
+
+ # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu`
+ # that the user might have specified. `kwargs` still has `mode`, which
+ # the test is allowed to accept or ignore.
+ requested_arguments = tf_inspect.getfullargspec(test_method).args
+ missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
+ set(requested_arguments + ["mode"]))
+ if missing_arguments:
+ raise ValueError("The test is missing arguments {} .".format(
+ missing_arguments))
+
+ kwargs_to_pass = {}
+ for arg in requested_arguments:
+ if arg == "self":
+ kwargs_to_pass[arg] = self
+ else:
+ kwargs_to_pass[arg] = kwargs[arg]
+
+ if mode == "eager":
+ with ops.Graph().as_default(), context.eager_mode():
+ test_method(**kwargs_to_pass)
+ elif mode == "graph":
+ with ops.Graph().as_default(), context.graph_mode():
+ test_method(**kwargs_to_pass)
+ else:
+ raise ValueError(
+ "'mode' has to be either 'eager' or 'graph' and not {}".format(
+ mode))
+ return decorated
+
+
def combine(**kwargs):
"""Generate combinations based on its keyword arguments.
diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py
index 184bcf27e5..86aa48cea8 100644
--- a/tensorflow/contrib/distribute/python/combinations_test.py
+++ b/tensorflow/contrib/distribute/python/combinations_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from collections import OrderedDict
+from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.python.eager import test
@@ -120,5 +121,28 @@ class TestingCombinationsTest(test.TestCase):
_ = combinations.times(c1, c2)
+@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1]))
+class CombineTheTestSuite(parameterized.TestCase):
+
+ def test_add_things(self, a, b, c):
+ self.assertLessEqual(3, a + b + c)
+ self.assertLessEqual(a + b + c, 5)
+
+ def test_add_things_one_more(self, a, b, c):
+ self.assertLessEqual(3, a + b + c)
+ self.assertLessEqual(a + b + c, 5)
+
+ def not_a_test(self, a=0, b=0, c=0):
+ del a, b, c
+ self.fail()
+
+ def _test_but_private(self, a=0, b=0, c=0):
+ del a, b, c
+ self.fail()
+
+ # Check that nothing funny happens to a non-callable that starts with "_test".
+ test_member = 0
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index c6a1bf6a9f..a411b880e8 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -77,12 +77,12 @@ def _all_devices_match(value_destination_pairs):
return True
-def _simple_broadcast(tensor, destinations):
+def _simple_broadcast(value, destinations):
index = {}
devices = _get_devices_from(destinations)
for d in devices:
- with ops.device(d):
- index[d] = array_ops.identity(tensor)
+ index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+ value, d)
return value_lib.Mirrored(index)
@@ -98,7 +98,9 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
continue
count += len(v_list)
# Sum within each device before aggregating across devices.
- v = math_ops.add_n(v_list)
+ # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
+ v = cross_tower_utils.aggregate_tensors_or_indexed_slices(
+ v_list, math_ops.add_n)
else:
count += 1
all_values.append(v)
@@ -107,11 +109,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
with ops.device(reduce_to_device):
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- if method_string == "sum":
- reduced = accumulation_fn(all_values)
- elif method_string == "mean":
- reduced = accumulation_fn(all_values) / count
- else:
+ reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
+ all_values, accumulation_fn)
+ if method_string == "mean":
+ reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
+ reduced, count)
+ elif method_string != "sum":
raise ValueError("`method_string` must be 'sum' or 'mean'")
return reduced
@@ -444,10 +447,18 @@ class AllReduceCrossTowerOps(CrossTowerOps):
super(AllReduceCrossTowerOps, self).__init__()
def _reduce(self, method_string, per_device_value, destinations):
+ contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
+ per_device_value)
if ((destinations is None or _devices_match(per_device_value, destinations))
- and not context.executing_eagerly()):
+ and not context.executing_eagerly()
+ and not contains_indexed_slices):
return self._batch_all_reduce(method_string, [per_device_value])[0]
else:
+ if contains_indexed_slices:
+ logging.log_first_n(
+ logging.WARN,
+ "Efficient allreduce is not supported for IndexedSlices.", 10)
+
devices = _get_devices_from(destinations or per_device_value)
reduce_to_device = devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
@@ -455,14 +466,18 @@ class AllReduceCrossTowerOps(CrossTowerOps):
return self.broadcast(reduced, devices)
def _batch_reduce(self, method_string, value_destination_pairs):
- if (_all_devices_match(value_destination_pairs) and
- not context.executing_eagerly()):
+ all_devices_match = _all_devices_match(value_destination_pairs)
+ contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
+ value_destination_pairs)
+ if (all_devices_match and not context.executing_eagerly()
+ and not contains_indexed_slices):
return self._batch_all_reduce(method_string,
[v[0] for v in value_destination_pairs])
else:
- if not context.executing_eagerly():
+ if not all_devices_match:
logging.warning("Efficient batch_reduce is not supported if "
"destinations are different.")
+
return [
self._reduce(method_string, t, destinations=v)
for t, v in value_destination_pairs
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 7c7b087088..2a26632608 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.training import device_util
def _make_per_device(values, devices):
@@ -56,19 +57,46 @@ def _fake_mirrored(value, devices):
{d: v for d, v in zip(devices, [value] * len(devices))})
+def _make_indexed_slices(values, indices, dense_shape, device):
+ with ops.device(device):
+ tensor = ops.IndexedSlices(
+ values=constant_op.constant(values),
+ indices=constant_op.constant(indices),
+ dense_shape=constant_op.constant(dense_shape))
+ return tensor
+
+
+def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
+ return value_lib.Mirrored({
+ d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices
+ })
+
+
_cpu_device = "/device:CPU:0"
class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
- def _assert_value_equal(self, left, right):
+ def _assert_indexed_slices_equal(self, left, right):
+ self.assertIsInstance(left, ops.IndexedSlices)
+ self.assertIsInstance(right, ops.IndexedSlices)
+ self.assertEqual(device_util.resolve(left.device),
+ device_util.resolve(right.device))
+ self.assertAllEqual(
+ self.evaluate(ops.convert_to_tensor(left)),
+ self.evaluate(ops.convert_to_tensor(right)))
+
+ def _assert_values_equal(self, left, right):
if isinstance(left, list):
for l, r in zip(left, right):
- self._assert_value_equal(l, r)
+ self._assert_values_equal(l, r)
else:
self.assertEqual(type(left), type(right))
self.assertEqual(left.devices, right.devices)
- if context.executing_eagerly():
+ if isinstance(list(left._index.values())[0], ops.IndexedSlices):
+ for (d, v) in left._index.iteritems():
+ self._assert_indexed_slices_equal(v, right._index[d])
+ elif context.executing_eagerly():
self.assertEqual([v.numpy() for v in left._index.values()],
list(right._index.values()))
else:
@@ -143,29 +171,29 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
# test reduce()
for destinations in all_destinations:
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.reduce("mean", per_device, destinations=destinations),
_fake_mirrored(mean, destinations or per_device))
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.reduce(
"mean", per_device_2, destinations=destinations),
_fake_mirrored(mean_2, destinations or per_device))
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.reduce("sum", per_device, destinations=destinations),
_fake_mirrored(mean * len(devices), destinations or per_device))
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.reduce(
"sum", per_device_2, destinations=destinations),
_fake_mirrored(mean_2 * len(devices), destinations or per_device))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.batch_reduce(
"mean", [(per_device, d1), (per_device_2, d2)]),
[_fake_mirrored(mean, d1 or per_device),
_fake_mirrored(mean_2, d2 or per_device_2)])
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.batch_reduce(
"sum", [(per_device, d1), (per_device_2, d2)]),
[_fake_mirrored(mean * len(devices), d1 or per_device),
@@ -176,7 +204,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
if destinations is None:
continue
else:
- self._assert_value_equal(
+ self._assert_values_equal(
cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
_fake_mirrored(1., destinations))
@@ -184,16 +212,14 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
- self.assertTrue(
- isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+ self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
self.assertEqual(result.num_packs, 8)
# if there are only 4 devices
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
- self.assertTrue(
- isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+ self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
self.assertEqual(result.all_reduce_alg, "nccl")
self.assertEqual(result.num_packs, 1)
@@ -202,8 +228,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
[0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
[2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
- self.assertTrue(
- isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+ self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
self.assertEqual(result.num_packs, 8)
@@ -211,11 +236,85 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
[1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
- self.assertTrue(
- isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps))
+ self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
self.assertEqual(result.all_reduce_alg, "nccl")
self.assertEqual(result.num_packs, 1)
+ @combinations.generate(combinations.combine(
+ mode=["graph", "eager"],
+ required_gpus=1))
+ def testSimpleReduceWithIndexedSlices(self):
+ devices = ["/cpu:0", "/gpu:0"]
+ t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
+ t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
+ per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
+ result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
+ math_ops.add_n, "sum")
+
+ # Test that the result is semantically equal to both the concatenated
+ # IndexedSlices with and without duplicate indices.
+ total_with_dups = _make_indexed_slices(
+ [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0])
+ total_without_dups = _make_indexed_slices(
+ [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
+ self._assert_indexed_slices_equal(total_with_dups, result)
+ self._assert_indexed_slices_equal(total_without_dups, result)
+
+ @combinations.generate(combinations.combine(
+ cross_tower_ops_instance=[
+ combinations.NamedObject(
+ "ReductionToOneDeviceCrossTowerOps",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+ combinations.NamedObject(
+ "AllReduceCrossTowerOps",
+ cross_tower_ops_lib.AllReduceCrossTowerOps())
+ ],
+ method_string=["sum", "mean"],
+ batch_reduce=[True, False],
+ mode=["graph", "eager"],
+ required_gpus=1))
+ def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
+ method_string, batch_reduce):
+ devices = ["/cpu:0", "/gpu:0"]
+ dense_shape = [5, 2]
+ t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
+ t1 = _make_indexed_slices(
+ [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1])
+ per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
+
+ if batch_reduce:
+ result = cross_tower_ops_instance.batch_reduce(method_string,
+ [(per_device, devices)])
+ else:
+ result = cross_tower_ops_instance.reduce(method_string, per_device,
+ devices)
+
+ total_indices_with_dups = [1, 1, 3]
+ total_indices_without_dups = [1, 3]
+
+ if method_string == "sum":
+ total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
+ total_values_without_dups = [[4., 6.], [5., 6.]]
+ else:
+ assert method_string == "mean"
+ total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
+ total_values_without_dups = [[2., 3.], [2.5, 3.]]
+
+ total_mirrored_with_dups = _make_mirrored_indexed_slices(
+ devices, total_values_with_dups, total_indices_with_dups, dense_shape)
+ total_mirrored_without_dups = _make_mirrored_indexed_slices(
+ devices, total_values_without_dups, total_indices_without_dups,
+ dense_shape)
+
+ # Test that the result is semantically equal to both the concatenated
+ # IndexedSlices, as well as when the duplicate indices are summed up.
+ if batch_reduce:
+ total_mirrored_with_dups = [total_mirrored_with_dups]
+ total_mirrored_without_dups = [total_mirrored_without_dups]
+
+ self._assert_values_equal(total_mirrored_with_dups, result)
+ self._assert_values_equal(total_mirrored_without_dups, result)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index fc04e2195f..137fabf4c7 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -21,9 +21,11 @@ from __future__ import print_function
import collections as pycoll
from tensorflow.contrib import nccl
+from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
@@ -337,3 +339,46 @@ def unpack_small_tensors(tower_grads, packing):
new_gv_list.insert(idx, gv[gi])
new_tower_grads.append(new_gv_list)
return new_tower_grads
+
+
+def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
+ """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
+ if any(isinstance(v, ops.IndexedSlices) for v in values):
+ return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access
+ else:
+ return accumulation_fn(values)
+
+
+def divide_by_n_tensors_or_indexed_slices(value, n):
+ if isinstance(value, ops.IndexedSlices):
+ value = gradients_impl._HandleNestedIndexedSlices(value) # pylint: disable=protected-access
+ return ops.IndexedSlices(
+ value.values / n, value.indices, value.dense_shape)
+ else:
+ return value / n
+
+
+def copy_tensor_or_indexed_slices_to_device(value, device):
+ with ops.device(device):
+ if isinstance(value, ops.IndexedSlices):
+ copied_values = array_ops.identity(value.values)
+ copied_indices = array_ops.identity(value.indices)
+ copied_shape = array_ops.identity(value.dense_shape)
+ result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
+ else:
+ result = array_ops.identity(value)
+ return result
+
+
+def contains_indexed_slices(value):
+ """Check whether the value is `IndexedSlices` or contains `IndexedSlices`."""
+ if isinstance(value, ops.IndexedSlices):
+ return True
+ elif isinstance(value, (list, tuple)) and value:
+ return any(contains_indexed_slices(v) for v in value)
+ elif isinstance(value, value_lib.DistributedValues):
+ return contains_indexed_slices(list(value._index.values())) # pylint: disable=protected-access
+ elif isinstance(value, value_lib.MapOutput):
+ return contains_indexed_slices(value.get())
+ else:
+ return False
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
new file mode 100644
index 0000000000..4ef8db6815
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cross_tower_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import device_util
+
+
+class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
+
+ def _assert_values_equal(self, left, right):
+ self.assertAllEqual(
+ self.evaluate(ops.convert_to_tensor(left)),
+ self.evaluate(ops.convert_to_tensor(right)))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testAggregateTensors(self):
+ t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+ t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
+ total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
+ result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
+ self._assert_values_equal(total, result)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testAggregateIndexedSlices(self):
+ t0 = math_ops._as_indexed_slices(
+ constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(
+ constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+ total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
+ result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
+ self.assertIsInstance(result, ops.IndexedSlices)
+ self._assert_values_equal(total, result)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testDivideTensor(self):
+ t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+ n = 2
+ expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
+ result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
+ self._assert_values_equal(expected, result)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testDivideIndexedSlices(self):
+ t = math_ops._as_indexed_slices(
+ constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+ n = 2
+ expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
+ result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
+ self.assertIsInstance(result, ops.IndexedSlices)
+ self._assert_values_equal(expected, result)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testIsIndexedSlices(self):
+ t = math_ops._as_indexed_slices(
+ constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+ self.assertTrue(cross_tower_utils.contains_indexed_slices(t))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testContainsIndexedSlices_List(self):
+ t0 = math_ops._as_indexed_slices(
+ constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(
+ constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+ self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1]))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testContainsIndexedSlices_Tuple(self):
+ t0 = math_ops._as_indexed_slices(
+ constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(
+ constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+ self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1)))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testContainsIndexedSlices_PerDevice(self):
+ t0 = math_ops._as_indexed_slices(
+ constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(
+ constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+ per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1})
+ self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testContainsIndexedSlices_PerDeviceMapOutput(self):
+ t0 = math_ops._as_indexed_slices(
+ constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(
+ constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
+ per_device = value_lib.PerDevice({
+ "/gpu:0": value_lib.MapOutput([t0]),
+ "/cpu:0": value_lib.MapOutput([t1])})
+ self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
+
+ @combinations.generate(combinations.combine(
+ mode=["graph", "eager"],
+ required_gpus=1))
+ def testCopyTensor(self):
+ with ops.device("/cpu:0"):
+ t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
+ destination = "/gpu:0"
+ result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+ t, destination)
+
+ self._assert_values_equal(t, result)
+ self.assertEqual(device_util.resolve(destination),
+ device_util.resolve(result.device))
+
+ @combinations.generate(combinations.combine(
+ mode=["graph", "eager"],
+ required_gpus=1))
+ def testCopyIndexedSlices(self):
+ with ops.device("/cpu:0"):
+ t = math_ops._as_indexed_slices(
+ constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
+ destination = "/gpu:0"
+ result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
+ t, destination)
+
+ self.assertIsInstance(result, ops.IndexedSlices)
+ self._assert_values_equal(t, result)
+ self.assertEqual(device_util.resolve(destination),
+ device_util.resolve(result.device))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
new file mode 100644
index 0000000000..75ecd90dcf
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -0,0 +1,148 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras Sequential and Functional models."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import keras as keras_lib
+from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import rmsprop
+
+_RANDOM_SEED = 1337
+_TRAIN_SIZE = 200
+_INPUT_SIZE = (10,)
+_NUM_CLASS = 2
+
+
+def simple_sequential_model():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE))
+ model.add(keras.layers.Dropout(0.1))
+ model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax'))
+ return model
+
+
+def simple_functional_model():
+ a = keras.layers.Input(shape=_INPUT_SIZE)
+ b = keras.layers.Dense(16, activation='relu')(a)
+ b = keras.layers.Dropout(0.1)(b)
+ b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b)
+ model = keras.models.Model(inputs=[a], outputs=[b])
+ return model
+
+
+def get_ds_train_input_fn():
+ np.random.seed(_RANDOM_SEED)
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=_INPUT_SIZE,
+ num_classes=_NUM_CLASS)
+ y_train = keras.utils.to_categorical(y_train)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
+ dataset = dataset.batch(32)
+ return dataset
+
+
+def get_ds_test_input_fn():
+ np.random.seed(_RANDOM_SEED)
+ _, (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=_INPUT_SIZE,
+ num_classes=_NUM_CLASS)
+ y_test = keras.utils.to_categorical(y_test)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test))
+ dataset = dataset.batch(32)
+ return dataset
+
+
+class TestKerasDistributionStrategy(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._base_dir = os.path.join(self.get_temp_dir(),
+ 'keras_mirrored_strategy_test')
+ gfile.MakeDirs(self._base_dir)
+ self._config = run_config_lib.RunConfig(
+ tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)
+
+ def tearDown(self):
+ writer_cache.FileWriterCache.clear()
+ if os.path.isdir(self._base_dir):
+ gfile.DeleteRecursively(self._base_dir)
+
+ def test_train_functional_with_distribution_strategy(self):
+ dist = mirrored_strategy.MirroredStrategy(
+ devices=['/device:GPU:0', '/device:GPU:1'])
+ keras_model = simple_functional_model()
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
+ config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
+ model_dir=self._base_dir,
+ train_distribute=dist)
+ with self.test_session():
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model, config=config)
+ before_eval_results = est_keras.evaluate(
+ input_fn=get_ds_test_input_fn, steps=1)
+ est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16)
+ after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn,
+ steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+ writer_cache.FileWriterCache.clear()
+ gfile.DeleteRecursively(self._config.model_dir)
+
+ def test_train_sequential_with_distribution_strategy(self):
+ dist = mirrored_strategy.MirroredStrategy(
+ devices=['/device:GPU:0', '/device:GPU:1'])
+ keras_model = simple_sequential_model()
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
+ config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
+ model_dir=self._base_dir,
+ train_distribute=dist)
+ with self.test_session():
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model, config=config)
+ before_eval_results = est_keras.evaluate(
+ input_fn=get_ds_test_input_fn, steps=1)
+ est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16)
+ after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn,
+ steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+ writer_cache.FileWriterCache.clear()
+ gfile.DeleteRecursively(self._config.model_dir)
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 89f2c431fe..14dbbd6e27 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
import threading
import six
@@ -39,6 +40,16 @@ from tensorflow.python.training import distribute as distribute_lib
# TODO(josh11b): Replace asserts in this file with if ...: raise ...
+@contextlib.contextmanager
+def _enter_graph(g):
+ if context.executing_eagerly():
+ with g.as_default(), context.eager_mode():
+ yield
+ else:
+ with g.as_default():
+ yield
+
+
def _cpu_device(device):
cpu_device = tf_device.DeviceSpec.from_string(device)
cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
@@ -458,7 +469,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with self.coord.stop_on_exception(), \
context.context()._mode(self.context_mode), \
context.context().device_policy(self.context_device_policy), \
- self.graph.as_default(), \
+ _enter_graph(self.graph), \
MirroredTowerContext(self.distribution, self.tower_id), \
ops.device(self.device), \
ops.name_scope(self._captured_name_scope), \
diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py
index 8277e1e791..4fdb9bf69b 100644
--- a/tensorflow/contrib/distribute/python/monitor_test.py
+++ b/tensorflow/contrib/distribute/python/monitor_test.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import monitor as monitor_lib
from tensorflow.contrib.distribute.python import one_device_strategy
from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example
+from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
@@ -65,7 +66,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase):
step_function, _ = single_loss_example(
lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution)
- with self.test_session() as sess:
+ with session.Session() as sess, context.eager_mode():
with self.assertRaisesRegexp(ValueError, "Should not provide"):
_ = monitor_lib.Monitor(step_function, sess)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py
index 8b279ebcd9..f8a52615b0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py
@@ -59,7 +59,7 @@ class ConditionalBijectorTest(test.TestCase):
for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]:
method = getattr(b, name)
with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"):
- method(1., event_ndims=0., arg1="b1", arg2="b2")
+ method(1., event_ndims=0, arg1="b1", arg2="b2")
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
index ce6cf702d5..9c4dfed836 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
@@ -98,23 +98,21 @@ class StatisticalTestingTest(test.TestCase):
num_samples = 5000
# 5000 samples is chosen to be enough to find discrepancies of
# size 0.1 or more with assurance 1e-6, as confirmed here:
- with self.test_session() as sess:
- d = st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6)
- d = sess.run(d)
- self.assertLess(d, 0.1)
+ d = st.min_discrepancy_of_true_means_detectable_by_dkwm(
+ num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6)
+ d = self.evaluate(d)
+ self.assertLess(d, 0.1)
# Test that the confidence interval computed for the mean includes
# 0.5 and excludes 0.4 and 0.6.
- with self.test_session() as sess:
- samples = rng.uniform(size=num_samples).astype(np.float32)
- (low, high) = st.true_mean_confidence_interval_by_dkwm(
- samples, 0., 1., error_rate=1e-6)
- low, high = sess.run([low, high])
- self.assertGreater(low, 0.4)
- self.assertLess(low, 0.5)
- self.assertGreater(high, 0.5)
- self.assertLess(high, 0.6)
+ samples = rng.uniform(size=num_samples).astype(np.float32)
+ (low, high) = st.true_mean_confidence_interval_by_dkwm(
+ samples, 0., 1., error_rate=1e-6)
+ low, high = self.evaluate([low, high])
+ self.assertGreater(low, 0.4)
+ self.assertLess(low, 0.5)
+ self.assertGreater(high, 0.5)
+ self.assertLess(high, 0.6)
def test_dkwm_mean_one_sample_assertion(self):
rng = np.random.RandomState(seed=0)
@@ -123,21 +121,45 @@ class StatisticalTestingTest(test.TestCase):
# Test that the test assertion agrees that the mean of the standard
# uniform distribution is 0.5.
samples = rng.uniform(size=num_samples).astype(np.float32)
- with self.test_session() as sess:
- sess.run(st.assert_true_mean_equal_by_dkwm(
- samples, 0., 1., 0.5, false_fail_rate=1e-6))
-
- # Test that the test assertion confirms that the mean of the
- # standard uniform distribution is not 0.4.
- with self.assertRaisesOpError("Mean confidence interval too high"):
- sess.run(st.assert_true_mean_equal_by_dkwm(
- samples, 0., 1., 0.4, false_fail_rate=1e-6))
-
- # Test that the test assertion confirms that the mean of the
- # standard uniform distribution is not 0.6.
- with self.assertRaisesOpError("Mean confidence interval too low"):
- sess.run(st.assert_true_mean_equal_by_dkwm(
- samples, 0., 1., 0.6, false_fail_rate=1e-6))
+ self.evaluate(st.assert_true_mean_equal_by_dkwm(
+ samples, 0., 1., 0.5, false_fail_rate=1e-6))
+
+ # Test that the test assertion confirms that the mean of the
+ # standard uniform distribution is not 0.4.
+ with self.assertRaisesOpError("true mean greater than expected"):
+ self.evaluate(st.assert_true_mean_equal_by_dkwm(
+ samples, 0., 1., 0.4, false_fail_rate=1e-6))
+
+ # Test that the test assertion confirms that the mean of the
+ # standard uniform distribution is not 0.6.
+ with self.assertRaisesOpError("true mean smaller than expected"):
+ self.evaluate(st.assert_true_mean_equal_by_dkwm(
+ samples, 0., 1., 0.6, false_fail_rate=1e-6))
+
+ def test_dkwm_mean_in_interval_one_sample_assertion(self):
+ rng = np.random.RandomState(seed=0)
+ num_samples = 5000
+
+ # Test that the test assertion agrees that the mean of the standard
+ # uniform distribution is between 0.4 and 0.6.
+ samples = rng.uniform(size=num_samples).astype(np.float32)
+ self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+ samples, 0., 1.,
+ expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6))
+
+ # Test that the test assertion confirms that the mean of the
+ # standard uniform distribution is not between 0.2 and 0.4.
+ with self.assertRaisesOpError("true mean greater than expected"):
+ self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+ samples, 0., 1.,
+ expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6))
+
+ # Test that the test assertion confirms that the mean of the
+ # standard uniform distribution is not between 0.6 and 0.8.
+ with self.assertRaisesOpError("true mean smaller than expected"):
+ self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+ samples, 0., 1.,
+ expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6))
def test_dkwm_mean_two_sample_assertion(self):
rng = np.random.RandomState(seed=0)
@@ -145,20 +167,18 @@ class StatisticalTestingTest(test.TestCase):
# 4000 samples is chosen to be enough to find discrepancies of
# size 0.2 or more with assurance 1e-6, as confirmed here:
- with self.test_session() as sess:
- d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(
- num_samples, 0., 1., num_samples, 0., 1.,
- false_fail_rate=1e-6, false_pass_rate=1e-6)
- d = sess.run(d)
- self.assertLess(d, 0.2)
+ d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(
+ num_samples, 0., 1., num_samples, 0., 1.,
+ false_fail_rate=1e-6, false_pass_rate=1e-6)
+ d = self.evaluate(d)
+ self.assertLess(d, 0.2)
# Test that the test assertion agrees that the standard
# uniform distribution has the same mean as itself.
samples1 = rng.uniform(size=num_samples).astype(np.float32)
samples2 = rng.uniform(size=num_samples).astype(np.float32)
- with self.test_session() as sess:
- sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
- samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6))
+ self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample(
+ samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6))
def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self):
rng = np.random.RandomState(seed=0)
@@ -168,15 +188,14 @@ class StatisticalTestingTest(test.TestCase):
# As established above, 4000 samples is enough to find discrepancies
# of size 0.2 or more with assurance 1e-6.
- with self.test_session() as sess:
- # Test that the test assertion confirms that the mean of the
- # standard uniform distribution is different from the mean of beta(2, 1).
- beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32)
- with self.assertRaisesOpError("samples1 has a smaller mean"):
- sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
- samples1, 0., 1.,
- beta_high_samples, 0., 1.,
- false_fail_rate=1e-6))
+ # Test that the test assertion confirms that the mean of the
+ # standard uniform distribution is different from the mean of beta(2, 1).
+ beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32)
+ with self.assertRaisesOpError("true mean smaller than expected"):
+ self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample(
+ samples1, 0., 1.,
+ beta_high_samples, 0., 1.,
+ false_fail_rate=1e-6))
def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self):
rng = np.random.RandomState(seed=0)
@@ -186,15 +205,14 @@ class StatisticalTestingTest(test.TestCase):
# As established above, 4000 samples is enough to find discrepancies
# of size 0.2 or more with assurance 1e-6.
- with self.test_session() as sess:
- # Test that the test assertion confirms that the mean of the
- # standard uniform distribution is different from the mean of beta(1, 2).
- beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32)
- with self.assertRaisesOpError("samples2 has a smaller mean"):
- sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
- samples1, 0., 1.,
- beta_low_samples, 0., 1.,
- false_fail_rate=1e-6))
+ # Test that the test assertion confirms that the mean of the
+ # standard uniform distribution is different from the mean of beta(1, 2).
+ beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32)
+ with self.assertRaisesOpError("true mean greater than expected"):
+ self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample(
+ samples1, 0., 1.,
+ beta_low_samples, 0., 1.,
+ false_fail_rate=1e-6))
def test_dkwm_argument_validity_checking(self):
rng = np.random.RandomState(seed=0)
@@ -203,18 +221,17 @@ class StatisticalTestingTest(test.TestCase):
# Test that the test library complains if the given samples fall
# outside the purported bounds.
- with self.test_session() as sess:
- with self.assertRaisesOpError("maximum value exceeds expectations"):
- sess.run(st.true_mean_confidence_interval_by_dkwm(
- samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5))
- with self.assertRaisesOpError("minimum value falls below expectations"):
- sess.run(st.true_mean_confidence_interval_by_dkwm(
- samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5))
-
- # But doesn't complain if they don't.
- op = st.true_mean_confidence_interval_by_dkwm(
- samples, [[0., 1.]], [[1., 2.]], error_rate=0.5)
- _ = sess.run(op)
+ with self.assertRaisesOpError("maximum value exceeds expectations"):
+ self.evaluate(st.true_mean_confidence_interval_by_dkwm(
+ samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5))
+ with self.assertRaisesOpError("minimum value falls below expectations"):
+ self.evaluate(st.true_mean_confidence_interval_by_dkwm(
+ samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5))
+
+ # But doesn't complain if they don't.
+ op = st.true_mean_confidence_interval_by_dkwm(
+ samples, [[0., 1.]], [[1., 2.]], error_rate=0.5)
+ _ = self.evaluate(op)
def test_do_maximum_mean(self):
n = 117
@@ -223,10 +240,9 @@ class StatisticalTestingTest(test.TestCase):
samples = rng.uniform(size=n).astype(np.float32)
# Compute the answer in TF using the code under test
- with self.test_session() as sess:
- envelope_t = ops.convert_to_tensor(envelope)
- max_mean = st._do_maximum_mean(samples, envelope_t, 1)
- max_mean = sess.run(max_mean)
+ envelope_t = ops.convert_to_tensor(envelope)
+ max_mean = st._do_maximum_mean(samples, envelope_t, 1)
+ max_mean = self.evaluate(max_mean)
# Compute the correct answer for this case in numpy. In this
# example, `n` and `envelope` are such that `samples[2]` is the
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
index d813831bef..11ca90c483 100644
--- a/tensorflow/contrib/distributions/python/ops/autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -144,7 +144,7 @@ class Autoregressive(distribution_lib.Distribution):
`distribution_fn(sample0).event_shape.num_elements()` are both `None`.
ValueError: if `num_steps < 1`.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name) as name:
self._distribution_fn = distribution_fn
self._sample0 = sample0
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index c709318f76..4714caad69 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
-from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -103,7 +102,7 @@ class BatchReshape(distribution_lib.Distribution):
ValueError: if `batch_shape` size is not the same as a
`distribution.batch_shape` size.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
name = name or "BatchReshape" + distribution.name
with ops.name_scope(name, values=[batch_shape]) as name:
# The unexpanded batch shape may contain up to one dimension of -1.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
index b158a51bb0..16f959560c 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
@@ -234,7 +234,7 @@ class Chain(bijector.Bijector):
if not self.bijectors:
return ildj
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
self.inverse_min_event_ndims)
if _use_static_shape(y, event_ndims):
@@ -248,12 +248,15 @@ class Chain(bijector.Bijector):
if _use_static_shape(y, event_ndims):
event_shape = b.inverse_event_shape(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
event_shape.ndims)
else:
event_shape = b.inverse_event_shape_tensor(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
- array_ops.size(event_shape))
+ event_ndims = array_ops.size(event_shape)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
+ if event_ndims_ is not None:
+ event_ndims = event_ndims_
+
y = b.inverse(y, **kwargs.get(b.name, {}))
return ildj
@@ -270,7 +273,7 @@ class Chain(bijector.Bijector):
if not self.bijectors:
return fldj
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
self.forward_min_event_ndims)
if _use_static_shape(x, event_ndims):
@@ -283,21 +286,14 @@ class Chain(bijector.Bijector):
x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
if _use_static_shape(x, event_ndims):
event_shape = b.forward_event_shape(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims)
else:
event_shape = b.forward_event_shape_tensor(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
- array_ops.size(event_shape))
+ event_ndims = array_ops.size(event_shape)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
+ if event_ndims_ is not None:
+ event_ndims = event_ndims_
x = b.forward(x, **kwargs.get(b.name, {}))
return fldj
-
- def _maybe_get_event_ndims_statically(self, event_ndims):
- event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically(
- event_ndims)
- if event_ndims_ is None:
- return event_ndims
- return event_ndims_
-
-
diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py
index 24b26bf124..e4944beedc 100644
--- a/tensorflow/contrib/distributions/python/ops/binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/binomial.py
@@ -163,7 +163,7 @@ class Binomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._total_count = self._maybe_assert_valid_total_count(
ops.convert_to_tensor(total_count, name="total_count"),
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index f5ffdd8731..23b6a83c17 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -29,7 +29,6 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
-from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
"Cauchy",
@@ -121,7 +120,7 @@ class Cauchy(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` have different `dtype`.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)]
if validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py
index 08cdc15828..686ae1ba74 100644
--- a/tensorflow/contrib/distributions/python/ops/chi2.py
+++ b/tensorflow/contrib/distributions/python/ops/chi2.py
@@ -25,7 +25,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import gamma
-from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -84,7 +83,7 @@ class Chi2(gamma.Gamma):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
# Even though all stats of chi2 are defined for valid parameters, this is
# not true in the parent class "gamma." therefore, passing
# allow_nan_stats=True
@@ -120,7 +119,7 @@ class Chi2WithAbsDf(Chi2):
validate_args=False,
allow_nan_stats=True,
name="Chi2WithAbsDf"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[df]) as name:
super(Chi2WithAbsDf, self).__init__(
df=math_ops.floor(
diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
index 10b4536135..3598c8d23e 100644
--- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import conditional_distribution
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import transformed_distribution
@@ -106,7 +105,7 @@ class ConditionalTransformedDistribution(
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self.bijector.inverse(y, **bijector_kwargs)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(
y, event_ndims=event_ndims, **bijector_kwargs)
if self.bijector._is_injective: # pylint: disable=protected-access
@@ -131,7 +130,7 @@ class ConditionalTransformedDistribution(
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self.bijector.inverse(y, **bijector_kwargs)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(
y, event_ndims=event_ndims, **bijector_kwargs)
if self.bijector._is_injective: # pylint: disable=protected-access
@@ -220,14 +219,14 @@ class ConditionalTransformedDistribution(
inv_cdf = self.distribution.quantile(value, **distribution_kwargs)
return self.bijector.forward(inv_cdf, **bijector_kwargs)
- def _maybe_get_event_ndims_statically(self):
+ def _maybe_get_static_event_ndims(self):
if self.event_shape.ndims is not None:
return self.event_shape.ndims
event_ndims = array_ops.size(self.event_shape_tensor())
- static_event_ndims = tensor_util.constant_value(event_ndims)
+ event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
- if static_event_ndims is not None:
- return static_event_ndims
+ if event_ndims_ is not None:
+ return event_ndims_
return event_ndims
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index 6d7d6d307b..c44c76a133 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -32,7 +32,6 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution
-from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
"Deterministic",
@@ -87,7 +86,7 @@ class _BaseDeterministic(distribution.Distribution):
Raises:
ValueError: If `loc` is a scalar.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[loc, atol, rtol]) as name:
loc = ops.convert_to_tensor(loc, name="loc")
if is_vector and validate_args:
diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py
index 446cff6ec2..e1e42ee95d 100644
--- a/tensorflow/contrib/distributions/python/ops/geometric.py
+++ b/tensorflow/contrib/distributions/python/ops/geometric.py
@@ -85,7 +85,7 @@ class Geometric(distribution.Distribution):
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits, probs, validate_args=validate_args, name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index ed9ea6f4f3..9d94fd11c6 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -29,7 +29,6 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
-from tensorflow.python.ops.distributions import util as distribution_util
class _Gumbel(distribution.Distribution):
@@ -125,7 +124,7 @@ class _Gumbel(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
index 7e12767f6d..9c96254d1c 100644
--- a/tensorflow/contrib/distributions/python/ops/half_normal.py
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -31,7 +31,6 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
-from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -106,7 +105,7 @@ class HalfNormal(distribution.Distribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index fa89fff3b7..cd6eaa8407 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -29,7 +29,6 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import kullback_leibler
-from tensorflow.python.ops.distributions import util as distribution_util
class Independent(distribution_lib.Distribution):
@@ -117,7 +116,7 @@ class Independent(distribution_lib.Distribution):
ValueError: if `reinterpreted_batch_ndims` exceeds
`distribution.batch_ndims`
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
name = name or "Independent" + distribution.name
self._distribution = distribution
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 85e8e10466..208057b34d 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -125,7 +125,7 @@ class InverseGamma(distribution.Distribution):
Raises:
TypeError: if `concentration` and `rate` are different dtypes.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[concentration, rate]) as name:
with ops.control_dependencies([
check_ops.assert_positive(concentration),
@@ -280,7 +280,7 @@ class InverseGammaWithSoftplusConcentrationRate(InverseGamma):
validate_args=False,
allow_nan_stats=True,
name="InverseGammaWithSoftplusConcentrationRate"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[concentration, rate]) as name:
super(InverseGammaWithSoftplusConcentrationRate, self).__init__(
concentration=nn.softplus(concentration,
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index 0103283259..27aa863440 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -31,7 +31,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
-from tensorflow.python.ops.distributions import util as distribution_util
class Logistic(distribution.Distribution):
@@ -120,7 +119,7 @@ class Logistic(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index d54f30dc63..bfb53a06c0 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -116,7 +116,7 @@ class Mixture(distribution.Distribution):
matching static batch shapes, or all components do not
have matching static event shapes.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
if not isinstance(cat, categorical.Categorical):
raise TypeError("cat must be a Categorical distribution, but saw: %s" %
cat)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index c7c90cf875..112eefd369 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -130,7 +130,7 @@ class MixtureSameFamily(distribution.Distribution):
ValueError: if `mixture_distribution` categories does not equal
`components_distribution` rightmost batch shape.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name) as name:
self._mixture_distribution = mixture_distribution
self._components_distribution = components_distribution
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index cad398582b..d2beb2aff0 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -193,7 +193,7 @@ class MultivariateNormalDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
@@ -224,7 +224,7 @@ class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
validate_args=False,
allow_nan_stats=True,
name="MultivariateNormalDiagWithSoftplusScale"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[scale_diag]) as name:
super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 1c11594df3..5117379b04 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -215,7 +215,7 @@ class MultivariateNormalDiagPlusLowRank(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
def _convert_to_tensor(x, name):
return None if x is None else ops.convert_to_tensor(x, name=name)
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index 47d7d13cf3..57f47db50c 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -24,7 +24,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
@@ -156,7 +155,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
Raises:
ValueError: if neither `loc` nor `covariance_matrix` are specified.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
# Convert the covariance_matrix up to a scale_tril and call MVNTriL.
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index 79916fef8d..6a0383db02 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -170,7 +170,7 @@ class MultivariateNormalLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index d6b0ed994e..c809ef3c1c 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -179,7 +179,7 @@ class MultivariateNormalTriL(
Raises:
ValueError: if neither `loc` nor `scale_tril` are specified.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
def _convert_to_tensor(x, name):
return None if x is None else ops.convert_to_tensor(x, name=name)
if loc is None and scale_tril is None:
diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
index 1085c56dc8..2bd11e24b3 100644
--- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
@@ -90,7 +90,7 @@ class NegativeBinomial(distribution.Distribution):
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits, probs, validate_args=validate_args, name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
index a4b9f3b78d..3e44c10fab 100644
--- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
@@ -115,7 +115,7 @@ class OneHotCategorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
name=name, logits=logits, probs=probs, validate_args=validate_args,
diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py
index b345394021..04de8106ee 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson.py
@@ -93,7 +93,7 @@ class Poisson(distribution.Distribution):
TypeError: if `rate` is not a float-type.
TypeError: if `log_rate` is not a float-type.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[rate]) as name:
if (rate is None) == (log_rate is None):
raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index fe72091d7d..7b10ba998f 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -255,7 +255,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
TypeError: if `quadrature_grid` and `quadrature_probs` have different base
`dtype`.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[loc, scale]) as name:
if loc is not None:
loc = ops.convert_to_tensor(loc, name="loc")
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 584d2c385f..5ac6c34b53 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -263,7 +263,7 @@ class QuantizedDistribution(distributions.Distribution):
`Distribution` or continuous.
NotImplementedError: If the base distribution does not implement `cdf`.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
values = (
list(distribution.parameters.values()) +
[low, high])
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
index 0362996e68..4182ca2b56 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
@@ -165,7 +165,7 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution):
Raises:
ValueError: If both `probs` and `logits` are passed, or if neither.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[logits, probs, temperature]) as name:
with ops.control_dependencies([check_ops.assert_positive(temperature)]
if validate_args else []):
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index 910c430ae7..5414f347cd 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -162,7 +162,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[logits, probs, temperature]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index f04dc8da39..a764544932 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -132,7 +132,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name,
values=[loc, scale, skewness, tailweight]) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
index 9c69435fac..c25e8c51d7 100644
--- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py
+++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
@@ -140,6 +140,7 @@ __all__ = [
"assert_true_mean_equal_by_dkwm",
"min_discrepancy_of_true_means_detectable_by_dkwm",
"min_num_samples_for_dkwm_mean_test",
+ "assert_true_mean_in_interval_by_dkwm",
"assert_true_mean_equal_by_dkwm_two_sample",
"min_discrepancy_of_true_means_detectable_by_dkwm_two_sample",
"min_num_samples_for_dkwm_mean_two_sample_test",
@@ -209,17 +210,17 @@ def _maximum_mean(samples, envelope, high, name=None):
separately.
Args:
- samples: Floating-point tensor of samples from the distribution(s)
+ samples: Floating-point `Tensor` of samples from the distribution(s)
of interest. Entries are assumed IID across the 0th dimension.
The other dimensions must broadcast with `envelope` and `high`.
- envelope: Floating-point tensor of sizes of admissible CDF
+ envelope: Floating-point `Tensor` of sizes of admissible CDF
envelopes (i.e., the `eps` above).
- high: Floating-point tensor of upper bounds on the distributions'
- supports.
+ high: Floating-point `Tensor` of upper bounds on the distributions'
+ supports. `samples <= high`.
name: A name for this operation (optional).
Returns:
- bound: Floating-point tensor of upper bounds on the true means.
+ bound: Floating-point `Tensor` of upper bounds on the true means.
Raises:
InvalidArgumentError: If some `sample` is found to be larger than
@@ -254,17 +255,17 @@ def _minimum_mean(samples, envelope, low, name=None):
separately.
Args:
- samples: Floating-point tensor of samples from the distribution(s)
+ samples: Floating-point `Tensor` of samples from the distribution(s)
of interest. Entries are assumed IID across the 0th dimension.
The other dimensions must broadcast with `envelope` and `low`.
- envelope: Floating-point tensor of sizes of admissible CDF
+ envelope: Floating-point `Tensor` of sizes of admissible CDF
envelopes (i.e., the `eps` above).
- low: Floating-point tensor of lower bounds on the distributions'
- supports.
+ low: Floating-point `Tensor` of lower bounds on the distributions'
+ supports. `samples >= low`.
name: A name for this operation (optional).
Returns:
- bound: Floating-point tensor of lower bounds on the true means.
+ bound: Floating-point `Tensor` of lower bounds on the true means.
Raises:
InvalidArgumentError: If some `sample` is found to be smaller than
@@ -300,12 +301,12 @@ def _dkwm_cdf_envelope(n, error_rate, name=None):
probability above.
Args:
- n: Tensor of numbers of samples drawn.
- error_rate: Floating-point tensor of admissible rates of mistakes.
+ n: `Tensor` of numbers of samples drawn.
+ error_rate: Floating-point `Tensor` of admissible rates of mistakes.
name: A name for this operation (optional).
Returns:
- eps: Tensor of maximum distances the true CDF can be from the
+ eps: `Tensor` of maximum distances the true CDF can be from the
empirical CDF. This scales as `O(sqrt(-log(error_rate)))` and
as `O(1 / sqrt(n))`. The shape is the broadcast of `n` and
`error_rate`.
@@ -324,8 +325,8 @@ def _check_shape_dominates(samples, parameters):
sample counts end up inflated.
Args:
- samples: A Tensor whose shape is to be protected against broadcasting.
- parameters: A list of Tensors who are parameters for the statistical test.
+ samples: A `Tensor` whose shape is to be protected against broadcasting.
+ parameters: A list of `Tensor`s who are parameters for the statistical test.
Returns:
samples: Return original `samples` with control dependencies attached
@@ -369,19 +370,23 @@ def true_mean_confidence_interval_by_dkwm(
members.
Args:
- samples: Floating-point tensor of samples from the distribution(s)
+ samples: Floating-point `Tensor` of samples from the distribution(s)
of interest. Entries are assumed IID across the 0th dimension.
The other dimensions must broadcast with `low` and `high`.
- low: Floating-point tensor of lower bounds on the distributions'
+ The support is bounded: `low <= samples <= high`.
+ low: Floating-point `Tensor` of lower bounds on the distributions'
supports.
- high: Floating-point tensor of upper bounds on the distributions'
+ high: Floating-point `Tensor` of upper bounds on the distributions'
supports.
- error_rate: *Scalar* admissible total rate of mistakes.
+ error_rate: *Scalar* floating-point `Tensor` admissible total rate
+ of mistakes.
name: A name for this operation (optional).
Returns:
- low: A floating-point tensor of stochastic lower bounds on the true means.
- high: A floating-point tensor of stochastic upper bounds on the true means.
+ low: A floating-point `Tensor` of stochastic lower bounds on the
+ true means.
+ high: A floating-point `Tensor` of stochastic upper bounds on the
+ true means.
"""
with ops.name_scope(
name, "true_mean_confidence_interval_by_dkwm",
@@ -436,15 +441,17 @@ def assert_true_mean_equal_by_dkwm(
the assertion will insist on stronger evidence to fail any one member.
Args:
- samples: Floating-point tensor of samples from the distribution(s)
+ samples: Floating-point `Tensor` of samples from the distribution(s)
of interest. Entries are assumed IID across the 0th dimension.
The other dimensions must broadcast with `low` and `high`.
- low: Floating-point tensor of lower bounds on the distributions'
+ The support is bounded: `low <= samples <= high`.
+ low: Floating-point `Tensor` of lower bounds on the distributions'
supports.
- high: Floating-point tensor of upper bounds on the distributions'
+ high: Floating-point `Tensor` of upper bounds on the distributions'
supports.
- expected: Floating-point tensor of expected true means.
- false_fail_rate: *Scalar* admissible total rate of mistakes.
+ expected: Floating-point `Tensor` of expected true means.
+ false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+ rate of mistakes.
name: A name for this operation (optional).
Returns:
@@ -454,20 +461,8 @@ def assert_true_mean_equal_by_dkwm(
with ops.name_scope(
name, "assert_true_mean_equal_by_dkwm",
[samples, low, high, expected, false_fail_rate]):
- samples = ops.convert_to_tensor(samples, name="samples")
- low = ops.convert_to_tensor(low, name="low")
- high = ops.convert_to_tensor(high, name="high")
- expected = ops.convert_to_tensor(expected, name="expected")
- false_fail_rate = ops.convert_to_tensor(
- false_fail_rate, name="false_fail_rate")
- samples = _check_shape_dominates(samples, [low, high, expected])
- min_mean, max_mean = true_mean_confidence_interval_by_dkwm(
- samples, low, high, error_rate=false_fail_rate)
- less_op = check_ops.assert_less(
- min_mean, expected, message="Mean confidence interval too high")
- with ops.control_dependencies([less_op]):
- return check_ops.assert_greater(
- max_mean, expected, message="Mean confidence interval too low")
+ return assert_true_mean_in_interval_by_dkwm(
+ samples, low, high, expected, expected, false_fail_rate)
def min_discrepancy_of_true_means_detectable_by_dkwm(
@@ -487,30 +482,35 @@ def min_discrepancy_of_true_means_detectable_by_dkwm(
with the same `false_pass_rate`.
Args:
- n: Tensor of numbers of samples to be drawn from the distributions
+ n: `Tensor` of numbers of samples to be drawn from the distributions
of interest.
- low: Floating-point tensor of lower bounds on the distributions'
+ low: Floating-point `Tensor` of lower bounds on the distributions'
supports.
- high: Floating-point tensor of upper bounds on the distributions'
+ high: Floating-point `Tensor` of upper bounds on the distributions'
supports.
- false_fail_rate: *Scalar* admissible total rate of false failures.
- false_pass_rate: *Scalar* admissible rate of false passes.
+ false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+ rate of false failures.
+ false_pass_rate: *Scalar* floating-point `Tensor` admissible rate
+ of false passes.
name: A name for this operation (optional).
Returns:
- discr: Tensor of lower bounds on the distances between true
+ discr: `Tensor` of lower bounds on the distances between true
means detectable by a DKWM-based test.
For each batch member `i`, of `K` total, drawing `n[i]` samples from
some scalar distribution supported on `[low[i], high[i]]` is enough
to detect a difference in means of size `discr[i]` or more.
Specifically, we guarantee that (a) if the true mean is the expected
- mean, `assert_true_mean_equal_by_dkwm` will fail with probability at
- most `false_fail_rate / K` (which amounts to `false_fail_rate` if
- applied to the whole batch at once), and (b) if the true mean
- differs from the expected mean by at least `discr[i]`,
- `assert_true_mean_equal_by_dkwm` will pass with probability at most
- `false_pass_rate`.
+ mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm`
+ (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with
+ probability at most `false_fail_rate / K` (which amounts to
+ `false_fail_rate` if applied to the whole batch at once), and (b) if
+ the true mean differs from the expected mean (resp. falls outside
+ the expected interval) by at least `discr[i]`,
+ `assert_true_mean_equal_by_dkwm`
+ (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with
+ probability at most `false_pass_rate`.
The detectable discrepancy scales as
@@ -558,17 +558,19 @@ def min_num_samples_for_dkwm_mean_test(
on a scalar distribution supported on `[low, high]`.
Args:
- discrepancy: Floating-point tensor of desired upper limits on mean
+ discrepancy: Floating-point `Tensor` of desired upper limits on mean
differences that may go undetected with probability higher than
`1 - false_pass_rate`.
- low: Tensor of lower bounds on the distributions' support.
- high: Tensor of upper bounds on the distributions' support.
- false_fail_rate: *Scalar* admissible total rate of false failures.
- false_pass_rate: *Scalar* admissible rate of false passes.
+ low: `Tensor` of lower bounds on the distributions' support.
+ high: `Tensor` of upper bounds on the distributions' support.
+ false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+ rate of false failures.
+ false_pass_rate: *Scalar* floating-point `Tensor` admissible rate
+ of false passes.
name: A name for this operation (optional).
Returns:
- n: Tensor of numbers of samples to be drawn from the distributions
+ n: `Tensor` of numbers of samples to be drawn from the distributions
of interest.
The `discrepancy`, `low`, and `high` tensors must have
@@ -578,12 +580,15 @@ def min_num_samples_for_dkwm_mean_test(
some scalar distribution supported on `[low[i], high[i]]` is enough
to detect a difference in means of size `discrepancy[i]` or more.
Specifically, we guarantee that (a) if the true mean is the expected
- mean, `assert_true_mean_equal_by_dkwm` will fail with probability at
- most `false_fail_rate / K` (which amounts to `false_fail_rate` if
- applied to the whole batch at once), and (b) if the true mean
- differs from the expected mean by at least `discrepancy[i]`,
- `assert_true_mean_equal_by_dkwm` will pass with probability at most
- `false_pass_rate`.
+ mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm`
+ (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with
+ probability at most `false_fail_rate / K` (which amounts to
+ `false_fail_rate` if applied to the whole batch at once), and (b) if
+ the true mean differs from the expected mean (resp. falls outside
+ the expected interval) by at least `discrepancy[i]`,
+ `assert_true_mean_equal_by_dkwm`
+ (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with
+ probability at most `false_pass_rate`.
The required number of samples scales
as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`,
@@ -610,6 +615,76 @@ def min_num_samples_for_dkwm_mean_test(
return math_ops.maximum(n1, n2)
+def assert_true_mean_in_interval_by_dkwm(
+ samples, low, high, expected_low, expected_high,
+ false_fail_rate=1e-6, name=None):
+ """Asserts the mean of the given distribution is in the given interval.
+
+ More precisely, fails if there is enough evidence (using the
+ [Dvoretzky-Kiefer-Wolfowitz-Massart inequality]
+ (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval))
+ that the mean of the distribution from which the given samples are
+ drawn is _outside_ the given interval with statistical significance
+ `false_fail_rate` or stronger, otherwise passes. If you also want
+ to check that you are gathering enough evidence that a pass is not
+ spurious, see `min_num_samples_for_dkwm_mean_test` and
+ `min_discrepancy_of_true_means_detectable_by_dkwm`.
+
+ Note that `false_fail_rate` is a total false failure rate for all
+ the assertions in the batch. As such, if the batch is nontrivial,
+ the assertion will insist on stronger evidence to fail any one member.
+
+ Args:
+ samples: Floating-point `Tensor` of samples from the distribution(s)
+ of interest. Entries are assumed IID across the 0th dimension.
+ The other dimensions must broadcast with `low` and `high`.
+ The support is bounded: `low <= samples <= high`.
+ low: Floating-point `Tensor` of lower bounds on the distributions'
+ supports.
+ high: Floating-point `Tensor` of upper bounds on the distributions'
+ supports.
+ expected_low: Floating-point `Tensor` of lower bounds on the
+ expected true means.
+ expected_high: Floating-point `Tensor` of upper bounds on the
+ expected true means.
+ false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+ rate of mistakes.
+ name: A name for this operation (optional).
+
+ Returns:
+ check: Op that raises `InvalidArgumentError` if any expected mean
+ interval does not overlap with the corresponding confidence
+ interval.
+ """
+ with ops.name_scope(
+ name, "assert_true_mean_in_interval_by_dkwm",
+ [samples, low, high, expected_low, expected_high, false_fail_rate]):
+ samples = ops.convert_to_tensor(samples, name="samples")
+ low = ops.convert_to_tensor(low, name="low")
+ high = ops.convert_to_tensor(high, name="high")
+ expected_low = ops.convert_to_tensor(expected_low, name="expected_low")
+ expected_high = ops.convert_to_tensor(expected_high, name="expected_high")
+ false_fail_rate = ops.convert_to_tensor(
+ false_fail_rate, name="false_fail_rate")
+ samples = _check_shape_dominates(
+ samples, [low, high, expected_low, expected_high])
+ min_mean, max_mean = true_mean_confidence_interval_by_dkwm(
+ samples, low, high, false_fail_rate)
+ # Assert that the interval [min_mean, max_mean] intersects the
+ # interval [expected_low, expected_high]. This is true if
+ # max_mean >= expected_low and min_mean <= expected_high.
+ # By DeMorgan's law, that's also equivalent to
+ # not (max_mean < expected_low or min_mean > expected_high),
+ # which is a way of saying the two intervals are not disjoint.
+ check_confidence_interval_can_intersect = check_ops.assert_greater_equal(
+ max_mean, expected_low, message="Confidence interval does not "
+ "intersect: true mean smaller than expected")
+ with ops.control_dependencies([check_confidence_interval_can_intersect]):
+ return check_ops.assert_less_equal(
+ min_mean, expected_high, message="Confidence interval does not "
+ "intersect: true mean greater than expected")
+
+
def assert_true_mean_equal_by_dkwm_two_sample(
samples1, low1, high1, samples2, low2, high2,
false_fail_rate=1e-6, name=None):
@@ -630,23 +705,26 @@ def assert_true_mean_equal_by_dkwm_two_sample(
the assertion will insist on stronger evidence to fail any one member.
Args:
- samples1: Floating-point tensor of samples from the
+ samples1: Floating-point `Tensor` of samples from the
distribution(s) A. Entries are assumed IID across the 0th
dimension. The other dimensions must broadcast with `low1`,
`high1`, `low2`, and `high2`.
- low1: Floating-point tensor of lower bounds on the supports of the
+ The support is bounded: `low1 <= samples1 <= high1`.
+ low1: Floating-point `Tensor` of lower bounds on the supports of the
distributions A.
- high1: Floating-point tensor of upper bounds on the supports of
+ high1: Floating-point `Tensor` of upper bounds on the supports of
the distributions A.
- samples2: Floating-point tensor of samples from the
+ samples2: Floating-point `Tensor` of samples from the
distribution(s) B. Entries are assumed IID across the 0th
dimension. The other dimensions must broadcast with `low1`,
`high1`, `low2`, and `high2`.
- low2: Floating-point tensor of lower bounds on the supports of the
+ The support is bounded: `low2 <= samples2 <= high2`.
+ low2: Floating-point `Tensor` of lower bounds on the supports of the
distributions B.
- high2: Floating-point tensor of upper bounds on the supports of
+ high2: Floating-point `Tensor` of upper bounds on the supports of
the distributions B.
- false_fail_rate: *Scalar* admissible total rate of mistakes.
+ false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+ rate of mistakes.
name: A name for this operation (optional).
Returns:
@@ -676,20 +754,10 @@ def assert_true_mean_equal_by_dkwm_two_sample(
# and sample counts should be valid; however, because the intervals
# scale as O(-log(false_fail_rate)), there doesn't seem to be much
# room to win.
- min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm(
- samples1, low1, high1, false_fail_rate / 2.)
min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm(
samples2, low2, high2, false_fail_rate / 2.)
- # I want to assert
- # not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2),
- # but I think I only have and-combination of asserts, so use DeMorgan.
- check_confidence_intervals_can_intersect = check_ops.assert_greater_equal(
- max_mean_1, min_mean_2, message="Confidence intervals do not "
- "intersect: samples1 has a smaller mean than samples2")
- with ops.control_dependencies([check_confidence_intervals_can_intersect]):
- return check_ops.assert_less_equal(
- min_mean_1, max_mean_2, message="Confidence intervals do not "
- "intersect: samples2 has a smaller mean than samples1")
+ return assert_true_mean_in_interval_by_dkwm(
+ samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.)
def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(
@@ -710,22 +778,24 @@ def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(
with the same `false_pass_rate`.
Args:
- n1: Tensor of numbers of samples to be drawn from the distributions A.
- low1: Floating-point tensor of lower bounds on the supports of the
+ n1: `Tensor` of numbers of samples to be drawn from the distributions A.
+ low1: Floating-point `Tensor` of lower bounds on the supports of the
distributions A.
- high1: Floating-point tensor of upper bounds on the supports of
+ high1: Floating-point `Tensor` of upper bounds on the supports of
the distributions A.
- n2: Tensor of numbers of samples to be drawn from the distributions B.
- low2: Floating-point tensor of lower bounds on the supports of the
+ n2: `Tensor` of numbers of samples to be drawn from the distributions B.
+ low2: Floating-point `Tensor` of lower bounds on the supports of the
distributions B.
- high2: Floating-point tensor of upper bounds on the supports of
+ high2: Floating-point `Tensor` of upper bounds on the supports of
the distributions B.
- false_fail_rate: *Scalar* admissible total rate of false failures.
- false_pass_rate: *Scalar* admissible rate of false passes.
+ false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+ rate of false failures.
+ false_pass_rate: *Scalar* floating-point `Tensor` admissible rate
+ of false passes.
name: A name for this operation (optional).
Returns:
- discr: Tensor of lower bounds on the distances between true means
+ discr: `Tensor` of lower bounds on the distances between true means
detectable by a two-sample DKWM-based test.
For each batch member `i`, of `K` total, drawing `n1[i]` samples
@@ -776,24 +846,26 @@ def min_num_samples_for_dkwm_mean_two_sample_test(
(https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval).
Args:
- discrepancy: Floating-point tensor of desired upper limits on mean
+ discrepancy: Floating-point `Tensor` of desired upper limits on mean
differences that may go undetected with probability higher than
`1 - false_pass_rate`.
- low1: Floating-point tensor of lower bounds on the supports of the
+ low1: Floating-point `Tensor` of lower bounds on the supports of the
distributions A.
- high1: Floating-point tensor of upper bounds on the supports of
+ high1: Floating-point `Tensor` of upper bounds on the supports of
the distributions A.
- low2: Floating-point tensor of lower bounds on the supports of the
+ low2: Floating-point `Tensor` of lower bounds on the supports of the
distributions B.
- high2: Floating-point tensor of upper bounds on the supports of
+ high2: Floating-point `Tensor` of upper bounds on the supports of
the distributions B.
- false_fail_rate: *Scalar* admissible total rate of false failures.
- false_pass_rate: *Scalar* admissible rate of false passes.
+ false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+ rate of false failures.
+ false_pass_rate: *Scalar* floating-point `Tensor` admissible rate
+ of false passes.
name: A name for this operation (optional).
Returns:
- n1: Tensor of numbers of samples to be drawn from the distributions A.
- n2: Tensor of numbers of samples to be drawn from the distributions B.
+ n1: `Tensor` of numbers of samples to be drawn from the distributions A.
+ n2: `Tensor` of numbers of samples to be drawn from the distributions B.
For each batch member `i`, of `K` total, drawing `n1[i]` samples
from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]`
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index cd6d749959..8d4914e16c 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -395,7 +395,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
ValueError: if `not distribution.is_scalar_batch`.
ValueError: if `not distribution.is_scalar_event`.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[mix_loc, temperature]) as name:
if not scale or len(scale) < 2:
raise ValueError("Must specify list (or list-like object) of scale "
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index 3465d66b30..a75b3f3df1 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -175,7 +175,7 @@ class VectorExponentialDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index 2c31b01984..a7d4c55be9 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -175,7 +175,7 @@ class VectorExponentialLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index 6a36018d6f..4a53e7a621 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -210,7 +210,7 @@ class VectorLaplaceDiag(
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name):
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index 97e5c76d80..0566e04fec 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -191,7 +191,7 @@ class VectorLaplaceLinearOperator(
ValueError: if `scale` is unspecified.
TypeError: if not `scale.dtype.is_floating`
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
if scale is None:
raise ValueError("Missing required `scale` parameter.")
if not scale.dtype.is_floating:
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index ff5ca45257..bb33cd0762 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -163,7 +163,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
Raises:
ValueError: if at most `scale_identity_multiplier` is specified.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(
name,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index 4742f75218..21f84dcbde 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -175,7 +175,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
graph_parents = [df, loc, scale_identity_multiplier, scale_diag,
scale_tril, scale_perturb_factor, scale_perturb_diag]
with ops.name_scope(name) as name:
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index f555867e7f..88d4280759 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -107,7 +107,7 @@ class _WishartLinearOperator(distribution.Distribution):
ValueError: if df < k, where scale operator event shape is
`(k, k)`
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
self._cholesky_input_output_matrices = cholesky_input_output_matrices
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[df, scale_operator]):
@@ -530,7 +530,7 @@ class WishartCholesky(_WishartLinearOperator):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[scale]) as name:
with ops.name_scope("init", values=[scale]):
scale = ops.convert_to_tensor(scale)
@@ -646,7 +646,7 @@ class WishartFull(_WishartLinearOperator):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name) as name:
with ops.name_scope("init", values=[scale]):
scale = ops.convert_to_tensor(scale)
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index c1fd9e0ed0..1d9371c7ac 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -7,6 +7,8 @@ py_library(
name = "examples_pip",
deps = [
"//tensorflow/contrib/eager/python/examples/gan:mnist",
+ "//tensorflow/contrib/eager/python/examples/l2hmc",
+ "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets",
"//tensorflow/contrib/eager/python/examples/linear_regression",
"//tensorflow/contrib/eager/python/examples/resnet50",
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD
new file mode 100644
index 0000000000..7bdf9053de
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD
@@ -0,0 +1,39 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+py_library(
+ name = "neural_nets",
+ srcs = ["neural_nets.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ ],
+)
+
+py_library(
+ name = "l2hmc",
+ srcs = ["l2hmc.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":neural_nets",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "l2hmc_test",
+ size = "large",
+ srcs = ["l2hmc_test.py"],
+ additional_deps = [
+ ":l2hmc",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
new file mode 100644
index 0000000000..98b4ce1b26
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
@@ -0,0 +1,382 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""L2HMC compatible with TensorFlow's eager execution.
+
+Reference [Generalizing Hamiltonian Monte Carlo with Neural
+Networks](https://arxiv.org/pdf/1711.09268.pdf)
+
+Code adapted from the released TensorFlow graph implementation by original
+authors https://github.com/brain-research/l2hmc.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import numpy.random as npr
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets
+
+
+class Dynamics(tf.keras.Model):
+ """Dynamics engine of naive L2HMC sampler.
+
+ Args:
+ x_dim: dimensionality of observed data
+ loglikelihood_fn: log-likelihood function of conditional probability
+ n_steps: number of leapfrog steps within each transition
+ eps: initial value learnable scale of step size
+ """
+
+ def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1):
+ super(Dynamics, self).__init__()
+
+ self.x_dim = x_dim
+ self.potential = loglikelihood_fn
+ self.n_steps = n_steps
+
+ self._construct_time()
+ self._construct_masks()
+
+ self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
+ self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)
+
+ self.eps = tfe.Variable(
+ initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
+
+ # TODO(lxuechen): Remove this after model.add_weight is in place
+ self.vars_not_in_layers = [self.eps]
+ self.vars_not_in_layers += self.position_fn.vars_not_in_layers
+ self.vars_not_in_layers += self.momentum_fn.vars_not_in_layers
+
+ def apply_transition(self, position):
+ """Propose a new state and perform the accept or reject step."""
+
+ # Simulate dynamics both forward and backward;
+ # Use sampled Bernoulli masks to compute the actual solutions
+ position_f, momentum_f, accept_prob_f = self.transition_kernel(
+ position, forward=True)
+ position_b, momentum_b, accept_prob_b = self.transition_kernel(
+ position, forward=False)
+
+ # Decide direction uniformly
+ forward_mask = tf.cast(
+ tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32)
+ backward_mask = 1. - forward_mask
+
+ # Obtain proposed states
+ position_post = (
+ forward_mask[:, None] * position_f +
+ backward_mask[:, None] * position_b)
+ momentum_post = (
+ forward_mask[:, None] * momentum_f +
+ backward_mask[:, None] * momentum_b)
+
+ # Probability of accepting the proposed states
+ accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b
+
+ # Accept or reject step
+ accept_mask = tf.cast(
+ accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32)
+ reject_mask = 1. - accept_mask
+
+ # Samples after accept/reject step
+ position_out = (
+ accept_mask[:, None] * position_post + reject_mask[:, None] * position)
+
+ return position_post, momentum_post, accept_prob, position_out
+
+ def transition_kernel(self, position, forward=True):
+ """Transition kernel of augmented leapfrog integrator."""
+
+ lf_fn = self._forward_lf if forward else self._backward_lf
+
+ # Resample momentum
+ momentum = tf.random_normal(tf.shape(position))
+ position_post, momentum_post = position, momentum
+ sumlogdet = 0.
+ # Apply augmented leapfrog steps
+ for i in range(self.n_steps):
+ position_post, momentum_post, logdet = lf_fn(position_post, momentum_post,
+ i)
+ sumlogdet += logdet
+
+ accept_prob = self._compute_accept_prob(position, momentum, position_post,
+ momentum_post, sumlogdet)
+
+ return position_post, momentum_post, accept_prob
+
+ def _forward_lf(self, position, momentum, i):
+ """One forward augmented leapfrog step. See eq (5-6) in paper."""
+
+ t = self._get_time(i)
+ mask, mask_inv = self._get_mask(i)
+ sumlogdet = 0.
+
+ momentum, logdet = self._update_momentum_forward(position, momentum, t)
+ sumlogdet += logdet
+
+ position, logdet = self._update_position_forward(position, momentum, t,
+ mask)
+ sumlogdet += logdet
+
+ position, logdet = self._update_position_forward(position, momentum, t,
+ mask_inv)
+ sumlogdet += logdet
+
+ momentum, logdet = self._update_momentum_forward(position, momentum, t)
+ sumlogdet += logdet
+
+ return position, momentum, tf.reduce_sum(sumlogdet, axis=1)
+
+ def _backward_lf(self, position, momentum, i):
+ """One backward augmented leapfrog step. See Appendix A in paper."""
+
+ # Reversed index/sinusoidal time
+ t = self._get_time(self.n_steps - i - 1)
+ mask, mask_inv = self._get_mask(self.n_steps - i - 1)
+ sumlogdet = 0.
+
+ momentum, logdet = self._update_momentum_backward(position, momentum, t)
+ sumlogdet += logdet
+
+ position, logdet = self._update_position_backward(position, momentum, t,
+ mask)
+ sumlogdet += logdet
+
+ position, logdet = self._update_position_backward(position, momentum, t,
+ mask_inv)
+ sumlogdet += logdet
+
+ momentum, logdet = self._update_momentum_backward(position, momentum, t)
+ sumlogdet += logdet
+
+ return position, momentum, tf.reduce_sum(sumlogdet, axis=1)
+
+ def _update_momentum_forward(self, position, momentum, t):
+ """Update v in the forward leapfrog step."""
+
+ grad = self.grad_potential(position)
+ scale, translation, transformed = self.momentum_fn([position, grad, t])
+ scale *= .5 * self.eps
+ transformed *= self.eps
+ momentum = (
+ momentum * tf.exp(scale) -
+ .5 * self.eps * (tf.exp(transformed) * grad - translation))
+
+ return momentum, scale
+
+ def _update_position_forward(self, position, momentum, t, mask):
+ """Update x in the forward leapfrog step."""
+
+ mask_inv = 1. - mask
+ scale, translation, transformed = self.position_fn(
+ [momentum, mask * position, t])
+ scale *= self.eps
+ transformed *= self.eps
+ position = (
+ mask * position +
+ mask_inv * (position * tf.exp(scale) + self.eps *
+ (tf.exp(transformed) * momentum + translation)))
+
+ return position, mask_inv * scale
+
+ def _update_momentum_backward(self, position, momentum, t):
+ """Update v in the backward leapfrog step. Inverting the forward update."""
+
+ grad = self.grad_potential(position)
+ scale, translation, transformed = self.momentum_fn([position, grad, t])
+ scale *= -.5 * self.eps
+ transformed *= self.eps
+ momentum = (
+ tf.exp(scale) * (momentum + .5 * self.eps *
+ (tf.exp(transformed) * grad - translation)))
+
+ return momentum, scale
+
+ def _update_position_backward(self, position, momentum, t, mask):
+ """Update x in the backward leapfrog step. Inverting the forward update."""
+
+ mask_inv = 1. - mask
+ scale, translation, transformed = self.position_fn(
+ [momentum, mask_inv * position, t])
+ scale *= -self.eps
+ transformed *= self.eps
+ position = (
+ mask_inv * position + mask * tf.exp(scale) *
+ (position - self.eps * tf.exp(transformed) * momentum + translation))
+
+ return position, mask * scale
+
+ def _compute_accept_prob(self, position, momentum, position_post,
+ momentum_post, sumlogdet):
+ """Compute the prob of accepting the proposed state given old state."""
+
+ old_hamil = self.hamiltonian(position, momentum)
+ new_hamil = self.hamiltonian(position_post, momentum_post)
+
+ return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.))
+
+ def _construct_time(self):
+ """Convert leapfrog step index into sinusoidal time."""
+
+ self.ts = []
+ for i in range(self.n_steps):
+ t = tf.constant(
+ [
+ np.cos(2 * np.pi * i / self.n_steps),
+ np.sin(2 * np.pi * i / self.n_steps)
+ ],
+ dtype=tf.float32)
+ self.ts.append(t[None, :])
+
+ def _get_time(self, i):
+ """Get sinusoidal time for i-th augmented leapfrog step."""
+
+ return self.ts[i]
+
+ def _construct_masks(self):
+ """Construct different binary masks for different time steps."""
+
+ self.masks = []
+ for _ in range(self.n_steps):
+ idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2]
+ mask = np.zeros((self.x_dim,))
+ mask[idx] = 1.
+ mask = tf.constant(mask, dtype=tf.float32)
+ self.masks.append(mask[None, :])
+
+ def _get_mask(self, i):
+ """Get binary masks for i-th augmented leapfrog step."""
+
+ m = self.masks[i]
+ return m, 1. - m
+
+ def kinetic(self, v):
+ """Compute the kinetic energy."""
+
+ return .5 * tf.reduce_sum(v**2, axis=1)
+
+ def hamiltonian(self, position, momentum):
+ """Compute the overall Hamiltonian."""
+
+ return self.potential(position) + self.kinetic(momentum)
+
+ def grad_potential(self, position, check_numerics=True):
+ """Get gradient of potential function at current location."""
+
+ if not tf.executing_eagerly():
+ # TODO(lxuechen): Change this to tfe.gradients_function when it works
+ grad = tf.gradients(self.potential(position), position)[0]
+ else:
+ grad = tfe.gradients_function(self.potential)(position)[0]
+
+ if check_numerics:
+ return tf.check_numerics(grad, message="gradient of potential")
+
+ return grad
+
+
+# Defining loss and grads for training
+def compute_loss(x, dynamics, scale=.1, eps=1e-4):
+ """Compute loss defined in equation (8)."""
+
+ z = tf.random_normal(tf.shape(x))
+ x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
+ z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
+
+ # Add eps for numerical stability; following released impl
+ x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
+ z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
+
+ loss = tf.reduce_mean(
+ (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
+
+ return loss, x_out
+
+
+def loss_and_grads(x, dynamics):
+ """Obtain loss value and gradients."""
+
+ with tf.GradientTape() as tape:
+ loss_val, x_out = compute_loss(x, dynamics)
+
+ vars_ = dynamics.variables + dynamics.vars_not_in_layers
+ grads = tape.gradient(loss_val, vars_)
+
+ return loss_val, grads, x_out
+
+
+def warmup(dynamics, optimizer, n_iters=1, n_samples=200):
+ """Warmup optimization to reduce overhead."""
+
+ samples = tf.random_normal(
+ shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
+
+ for _ in range(n_iters):
+ _, grads, samples = loss_and_grads(samples, dynamics)
+ vars_ = dynamics.variables + dynamics.vars_not_in_layers
+ optimizer.apply_gradients(zip(grads, vars_))
+
+
+def fit(dynamics,
+ optimizer,
+ n_samples=200,
+ n_iters=5000,
+ verbose=True,
+ logdir=None):
+ """Fit L2HMC sampler with given log-likelihood function."""
+
+ if logdir:
+ summary_writer = tf.contrib.summary.create_file_writer(logdir)
+
+ samples = tf.random_normal(
+ shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
+
+ tf.train.get_or_create_global_step()
+ for i in range(n_iters):
+ loss, grads, samples = loss_and_grads(samples, dynamics)
+ # TODO(lxuechen): Proper learning rate decay
+ grads_ = [grad * .96**(i // 1000) for grad in grads]
+ vars_ = dynamics.variables + dynamics.vars_not_in_layers
+ optimizer.apply_gradients(
+ zip(grads_, vars_), global_step=tf.train.get_global_step())
+
+ if verbose:
+ print("Iteration %d: loss %.4f" % (i, loss))
+
+ if logdir:
+ with summary_writer.as_default():
+ with tf.contrib.summary.always_record_summaries():
+ tf.contrib.summary.scalar("loss", loss)
+
+
+def get_scg_energy_fn():
+ """Get energy function for 2d strongly correlated Gaussian."""
+
+ # Avoid recreating tf constants on each invocation of gradients
+ mu = tf.constant([0., 0.])
+ sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]])
+ sigma_inv = tf.matrix_inverse(sigma)
+
+ def energy(x):
+ """Unnormalized log density/energy of 2d strongly correlated Gaussian."""
+
+ xmmu = x - mu
+ return .5 * tf.diag_part(
+ tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))
+
+ return energy
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
new file mode 100644
index 0000000000..522a7c9380
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -0,0 +1,162 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests l2hmc fit to 2D strongly correlated Gaussian executed eagerly."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy.random as npr
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc
+
+
+def get_default_hparams():
+ return tf.contrib.training.HParams(
+ x_dim=2,
+ n_samples=200,
+ n_steps=10,
+ eps=.1,
+ n_iters=5,
+ learning_rate=.001,
+ n_warmup_iters=1)
+
+
+class L2hmcTest(tf.test.TestCase):
+ """Unit tests for l2hmc in both eager and graph mode."""
+
+ def testComputeLoss(self):
+ """Testing function l2hmc.compute_loss in both graph and eager mode."""
+
+ # Eager mode testing
+ hparams = get_default_hparams()
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim])
+ loss, x_out = l2hmc.compute_loss(samples, dynamics)
+
+ # Check shape and numerical stability
+ self.assertEqual(x_out.shape, samples.shape)
+ self.assertEqual(loss.shape, [])
+ self.assertAllClose(loss.numpy(), loss.numpy(), rtol=1e-5)
+
+ # Graph mode testing
+ with tf.Graph().as_default():
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
+ loss, x_out = l2hmc.compute_loss(x, dynamics)
+ samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
+
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ loss_np, x_out_np = sess.run([loss, x_out], feed_dict={x: samples})
+
+ # Check shape and numerical stability
+ self.assertEqual(x_out_np.shape, samples.shape)
+ self.assertEqual(loss_np.shape, ())
+ self.assertAllClose(loss_np, loss_np, rtol=1e-5)
+
+
+class L2hmcBenchmark(tf.test.Benchmark):
+ """Eager and graph benchmarks for l2hmc."""
+
+ def benchmarkEagerL2hmc(self):
+ """Benchmark Eager performance."""
+
+ hparams = get_default_hparams()
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ # TODO(lxuechen): Add learning rate decay
+ optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
+
+ # Warmup to reduce initialization effect when timing
+ l2hmc.warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters)
+
+ # Time
+ start_time = time.time()
+ l2hmc.fit(
+ dynamics,
+ optimizer,
+ n_samples=hparams.n_samples,
+ n_iters=hparams.n_iters)
+ wall_time = time.time() - start_time
+ examples_per_sec = hparams.n_samples / wall_time
+
+ self.report_benchmark(
+ name="eager_train_%s" % ("gpu" if tfe.num_gpus() > 0 else "cpu"),
+ iters=hparams.n_iters,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
+
+ def benchmarkGraphL2hmc(self):
+ """Benchmark Graph performance."""
+
+ hparams = get_default_hparams()
+ with tf.Graph().as_default():
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
+ loss, x_out = l2hmc.compute_loss(x, dynamics)
+
+ global_step = tf.Variable(0., name="global_step", trainable=False)
+ learning_rate = tf.train.exponential_decay(
+ hparams.learning_rate, global_step, 1000, 0.96, staircase=True)
+ optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
+ train_op = optimizer.minimize(loss, global_step=global_step)
+
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+
+ # Warmup to reduce initialization effect when timing
+ samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
+ for _ in range(hparams.n_warmup_iters):
+ samples, _, _, _ = sess.run(
+ [x_out, loss, train_op, learning_rate], feed_dict={x: samples})
+
+ # Time
+ start_time = time.time()
+ for _ in range(hparams.n_iters):
+ samples, _, _, _ = sess.run(
+ [x_out, loss, train_op, learning_rate], feed_dict={x: samples})
+ wall_time = time.time() - start_time
+ examples_per_sec = hparams.n_samples / wall_time
+
+ self.report_benchmark(
+ name="graph_train_%s" % ("gpu"
+ if tf.test.is_gpu_available() else "cpu"),
+ iters=hparams.n_iters,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
new file mode 100644
index 0000000000..c902e1f1f4
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
@@ -0,0 +1,86 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Neural nets utility for L2HMC compatible with TensorFlow's eager execution.
+
+Reference [Generalizing Hamiltonian Monte Carlo with Neural
+Networks](https://arxiv.org/pdf/1711.09268.pdf)
+
+Code adapted from the released TensorFlow graph implementation by original
+authors https://github.com/brain-research/l2hmc.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+
+
+class GenericNet(tf.keras.Model):
+ """Generic neural net with different initialization scale based on input.
+
+ Args:
+ x_dim: dimensionality of observed data
+ factor: factor of variance scaling initializer
+ n_hidden: number of hidden units
+ """
+
+ def __init__(self, x_dim, factor, n_hidden=10):
+ super(GenericNet, self).__init__()
+
+ self.v_layer = _custom_dense(n_hidden, 1. / 3.)
+ self.x_layer = _custom_dense(n_hidden, factor / 3.)
+ self.t_layer = _custom_dense(n_hidden, 1. / 3.)
+ self.h_layer = _custom_dense(n_hidden)
+
+ # Scale
+ self.scale_layer = _custom_dense(x_dim, .001)
+ self.coeff_scale = tfe.Variable(
+ initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True)
+ # Translation
+ self.translation_layer = _custom_dense(x_dim, factor=.001)
+ # Transformation
+ self.transformation_layer = _custom_dense(x_dim, .001)
+ self.coeff_transformation = tfe.Variable(
+ initial_value=tf.zeros([1, x_dim]),
+ name='coeff_transformation',
+ trainable=True)
+ # TODO(lxuechen): Remove this after model.add_weight is in place
+ self.vars_not_in_layers = [self.coeff_scale, self.coeff_transformation]
+
+ def call(self, inputs):
+ v, x, t = inputs
+ h = self.v_layer(v) + self.x_layer(x) + self.t_layer(t)
+ h = tf.nn.relu(h)
+ h = self.h_layer(h)
+ h = tf.nn.relu(h)
+ scale = tf.nn.tanh(self.scale_layer(h)) * tf.exp(self.coeff_scale)
+ translation = self.translation_layer(h)
+ transformation = (
+ tf.nn.tanh(self.transformation_layer(h)) * tf.exp(
+ self.coeff_transformation))
+
+ return scale, translation, transformation
+
+
+def _custom_dense(units, factor=1.):
+ """Custom dense layer with specified weight initialization."""
+
+ return tf.keras.layers.Dense(
+ units=units,
+ use_bias=True,
+ kernel_initializer=tf.contrib.layers.variance_scaling_initializer(
+ factor=factor * 2., mode='FAN_IN', uniform=False),
+ bias_initializer=tf.constant_initializer(0., dtype=tf.float32))
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
index 2259c20741..099b712fc0 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py
@@ -75,7 +75,6 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None):
mse = lambda xs, ys: mean_square_loss(model, xs, ys)
loss_and_grads = tfe.implicit_value_and_gradients(mse)
- tf.train.get_or_create_global_step()
if logdir:
# Support for TensorBoard summaries. Once training has started, use:
# tensorboard --logdir=<logdir>
@@ -87,12 +86,13 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None):
if verbose:
print("Iteration %d: loss = %s" % (i, loss.numpy()))
- optimizer.apply_gradients(grads, global_step=tf.train.get_global_step())
+ optimizer.apply_gradients(grads)
if logdir:
with summary_writer.as_default():
with tf.contrib.summary.always_record_summaries():
- tf.contrib.summary.scalar("loss", loss)
+ tf.contrib.summary.scalar("loss", loss, step=i)
+ tf.contrib.summary.scalar("step", i, step=i)
def synthetic_dataset(w, b, noise_level, batch_size, num_batches):
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb
index 9fd2d8d125..51d10a7784 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb
@@ -1,495 +1,429 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "Eager Execution Tutorial: Basics",
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "provenance": [
- {
- "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg",
- "timestamp": 1504118841551
- }
- ]
- }
- },
"cells": [
{
+ "cell_type": "markdown",
"metadata": {
- "id": "U9i2Dsh-ziXr",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "U9i2Dsh-ziXr"
},
- "cell_type": "markdown",
"source": [
- "# Eager Execution Tutorial: Basics\n",
+ "# An introduction to TensorFlow\n",
"\n",
- "This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n",
+ "This is an introductory tutorial for using TensorFlow. It will cover:\n",
"\n",
"* Importing required packages\n",
- "* Enabling eager execution\n",
- "* Creating and using TensorFlow Tensors and Variables\n",
- "* Using TensorFlow interactively\n",
- "* Using GPUs with eager execution enabled\n",
- "\n",
- "This notebook does *not* cover modeling topics, such as gradients."
+ "* Creating and using Tensors\n",
+ "* Using GPU acceleration\n"
]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "z1JcS5iBXMRO",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "z1JcS5iBXMRO"
},
- "cell_type": "markdown",
"source": [
- "# Step 1: Import Eager\n",
+ "## Import TensorFlow\n",
"\n",
- "The key imports for eager execution are the following:"
+ "To get started, import the `tensorflow` module and enable eager execution.\n",
+ "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later."
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "RlIWhyeLoYnG",
- "colab_type": "code",
+ "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
- "cellView": "code"
+ "colab_type": "code",
+ "id": "RlIWhyeLoYnG"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
- "# Import TensorFlow.\n",
"import tensorflow as tf\n",
"\n",
- "# Import TensorFlow eager execution support (subject to future changes).\n",
- "tfe = tf.contrib.eager"
- ],
- "execution_count": 0,
- "outputs": []
+ "tf.enable_eager_execution()"
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "H9UySOPLXdaw",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "H9UySOPLXdaw"
},
- "cell_type": "markdown",
"source": [
- "# Step 2: Enable eager execution\n",
+ "## Tensors\n",
"\n",
- "All future TensorFlow calls will execute the\n",
- "underlying TensorFlow ops immediately:"
+ "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv) etc.) that consume and produce Tensors. These operations automatically convert native Python types. For example:\n"
]
},
{
- "metadata": {
- "id": "WPTUfGq6kJ5w",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
- },
"cell_type": "code",
- "source": [
- "tf.enable_eager_execution()"
- ],
"execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "twBfWd5xyu_d",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Step 3: Interactively Use TensorFlow!\n",
- "\n",
- "Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n",
- "\n",
- "TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors."
- ]
- },
- {
"metadata": {
- "id": "ngUe237Wt48W",
- "colab_type": "code",
+ "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
- }
+ },
+ "height": 125
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 320,
+ "status": "ok",
+ "timestamp": 1526420535530,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
},
- "cellView": "code"
+ "id": "ngUe237Wt48W",
+ "outputId": "b1a1cd60-4eb3-443d-cd6b-68406390784e"
},
- "cell_type": "code",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tf.Tensor(3, shape=(), dtype=int32)\n",
+ "tf.Tensor([4 6], shape=(2,), dtype=int32)\n",
+ "tf.Tensor(25, shape=(), dtype=int32)\n",
+ "tf.Tensor(6, shape=(), dtype=int32)\n",
+ "tf.Tensor(aGVsbG8gd29ybGQ, shape=(), dtype=string)\n",
+ "tf.Tensor(13, shape=(), dtype=int32)\n"
+ ]
+ }
+ ],
"source": [
"print(tf.add(1, 2))\n",
"print(tf.add([1, 2], [3, 4]))\n",
"print(tf.square(5))\n",
"print(tf.reduce_sum([1, 2, 3]))\n",
"print(tf.encode_base64(\"hello world\"))\n",
- "print(\"\")\n",
"\n",
- "x = tf.constant(2)\n",
- "y = tf.constant(3)\n",
- "print(x * y + 1)\n",
- "\n",
- "# Most TensorFlow ops are directly usable with eager execution, giving\n",
- "# results immediately.\n",
- "print(tf.contrib.signal.hamming_window(x * y + 1))"
- ],
- "execution_count": 0,
- "outputs": []
+ "# Operator overloading is also supported\n",
+ "print(tf.square(2) + tf.square(3))"
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "IDY4WsYRhP81",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "IDY4WsYRhP81"
},
- "cell_type": "markdown",
"source": [
- "Numpy arrays are supported, too:"
+ "Each Tensor has a shape and a datatype"
]
},
{
- "metadata": {
- "id": "lCUWzso6mbqR",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
"cell_type": "code",
- "source": [
- "import numpy as np\n",
- "\n",
- "ones = np.ones([3, 3])\n",
- "\n",
- "print(\"numpy 3x3 matrix of 1s:\")\n",
- "print(ones)\n",
- "print(\"\")\n",
- "\n",
- "print(\"Multiplied by 42:\")\n",
- "print(tf.multiply(ones, 42))"
- ],
"execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "PBNP8yTRfu_X",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Step 4: Define and Print TensorFlow Variables\n",
- "\n",
- "To define TensorFlow variables, use the `get_variable()` function as follows:"
- ]
- },
- {
"metadata": {
- "id": "3Twf_Rw-gQFM",
- "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
- }
+ },
+ "height": 53
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 215,
+ "status": "ok",
+ "timestamp": 1526420538162,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
},
- "cellView": "code"
+ "id": "srYWH1MdJNG7",
+ "outputId": "5e4ac41c-5115-4e50-eba0-42e249c16561"
},
- "cell_type": "code",
- "source": [
- "x = tfe.Variable(0.)"
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1, 2)\n",
+ "\u003cdtype: 'int32'\u003e\n"
+ ]
+ }
],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "45G7094TxsMb",
- "colab_type": "text"
- },
- "cell_type": "markdown",
"source": [
- "## Printing TensorFlow Variables"
+ "x = tf.matmul([[1]], [[2, 3]])\n",
+ "print(x.shape)\n",
+ "print(x.dtype)"
]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "UJBJeZ5XxuwA",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
+ "colab_type": "text",
+ "id": "eBPw8e8vrsom"
},
- "cell_type": "code",
"source": [
- "# This does NOT print the Variable's actual value:\n",
- "print(\"Printing a TensorFlow Variable:\")\n",
- "print(x)\n",
- "print(\"\")\n",
+ "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n",
"\n",
- "\n",
- "print(\"Printing a TensorFlow Variable's value as a numpy array:\")\n",
- "print(x.numpy())"
- ],
- "execution_count": 0,
- "outputs": []
+ "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n",
+ "2. Tensors are immutable."
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "2njjWHcTpBEn",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "Dwi1tdW3JBw6"
},
- "cell_type": "markdown",
"source": [
- "## Changing a TensorFlow Variable's value\n",
+ "### NumPy Compatibility\n",
"\n",
- "To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:"
+ "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple as:\n",
+ "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n",
+ "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n",
+ "\n",
+ "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n",
+ "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory."
]
},
{
- "metadata": {
- "id": "v3wr6Erbo_hB",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
"cell_type": "code",
- "source": [
- "x.assign(42)\n",
- "print(x)\n",
- "\n",
- "x.assign_add(3)\n",
- "print(x)"
- ],
"execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "uhtynjHVpTB5",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Use a Variable just like any other Tensor"
- ]
- },
- {
- "metadata": {
- "id": "7PbktdnHoehR",
- "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
- }
- }
+ },
+ "height": 251
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 238,
+ "status": "ok",
+ "timestamp": 1526420540562,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "lCUWzso6mbqR",
+ "outputId": "fd0a22bc-8249-49dd-fcbd-63161cc47e46"
},
- "cell_type": "code",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TensorFlow operations convert numpy arrays to Tensors automatically\n",
+ "tf.Tensor(\n",
+ "[[ 42. 42. 42.]\n",
+ " [ 42. 42. 42.]\n",
+ " [ 42. 42. 42.]], shape=(3, 3), dtype=float64)\n",
+ "And NumPy operations convert Tensors to numpy arrays automatically\n",
+ "[[ 43. 43. 43.]\n",
+ " [ 43. 43. 43.]\n",
+ " [ 43. 43. 43.]]\n",
+ "The .numpy() method explicitly converts a Tensor to a numpy array\n",
+ "[[ 42. 42. 42.]\n",
+ " [ 42. 42. 42.]\n",
+ " [ 42. 42. 42.]]\n"
+ ]
+ }
+ ],
"source": [
- "print(x + 3)\n",
+ "import numpy as np\n",
"\n",
- "# This code will broadcast the value across the list of numbers:\n",
- "print(x * [1, 2, 4])"
- ],
- "execution_count": 0,
- "outputs": []
+ "ndarray = np.ones([3, 3])\n",
+ "\n",
+ "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n",
+ "tensor = tf.multiply(ndarray, 42)\n",
+ "print(tensor)\n",
+ "\n",
+ "\n",
+ "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n",
+ "print(np.add(tensor, 1))\n",
+ "\n",
+ "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n",
+ "print(tensor.numpy())"
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "GVChqwlwy1SI",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "PBNP8yTRfu_X"
},
- "cell_type": "markdown",
"source": [
- "# Step 5: Debug Errors with Instant Feedback\n",
+ "## GPU acceleration\n",
"\n",
- "TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n",
- "\n",
- "Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n",
- "one being legal and the other being illegal, leading to a runtime error that is\n",
- "raised immediately."
+ "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. For example:"
]
},
{
- "metadata": {
- "id": "23ap04N0v4k0",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
- },
"cell_type": "code",
- "source": [
- "vector = tf.constant([10.0, 20.0, 30.0, 40.0])"
- ],
"execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "FCUMsIYxxRRa",
- "colab_type": "code",
+ "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
- }
+ },
+ "height": 53
},
- "cellView": "code"
- },
- "cell_type": "code",
- "source": [
- "# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n",
- "# arguments) are within the bound of `vector`.\n",
- "print(tf.slice(vector, [1], [3]))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "T8me2oCNxpFp",
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
+ "executionInfo": {
+ "elapsed": 340,
+ "status": "ok",
+ "timestamp": 1526420543562,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
},
- "cellView": "code"
+ "id": "3Twf_Rw-gQFM",
+ "outputId": "2239ae2b-adf3-4895-b1f3-464cf5361d1b"
},
- "cell_type": "code",
- "source": [
- "# The following does NOT work, because the value of `size` (the 3rd\n",
- "# argument) causes the indices to go out of the bounds of `vector`. The\n",
- "# error is raised immediately.\n",
- "try:\n",
- " print(tf.slice(vector, [1], [4]))\n",
- "except tf.OpError as e:\n",
- " print(\"Caught error: %s\" % e)"
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Is there a GPU available: False\n",
+ "Is the Tensor on GPU #0: False\n"
+ ]
+ }
],
- "execution_count": 0,
- "outputs": []
+ "source": [
+ "x = tf.random_uniform([3, 3])\n",
+ "\n",
+ "print(\"Is there a GPU available: \"),\n",
+ "print(tf.test.is_gpu_available())\n",
+ "\n",
+ "print(\"Is the Tensor on GPU #0: \"),\n",
+ "print(x.device.endswith('GPU:0'))"
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "irxJhAgar84v",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "vpgYzgVXW2Ud"
},
- "cell_type": "markdown",
"source": [
- "# Step 6: Using the GPU\n",
- "\n",
- "You can explicitly place Tensors on the GPU by calling a Tensor's `.gpu()` method. The `.device` property tells you whether the Tensor is backed by CPU or GPU memory.\n",
+ "### Device Names\n",
"\n",
- "The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster."
+ "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:\u003cN\u003e` if the tensor is placed on the `N`-th tensor on the host."
]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "7J4N9baqaKCL",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "colab_type": "text",
+ "id": "ZWZQCimzuqyP"
},
- "cell_type": "code",
"source": [
- "# Create some Tensors\n",
- "SIZE = 1000\n",
- "tensor = tf.random_normal([SIZE, SIZE])\n",
- "print(tensor.device)\n",
"\n",
"\n",
- "if tf.test.is_gpu_available():\n",
- " gpu_tensor = tensor.gpu()\n",
- " cpu_tensor = tensor.cpu()\n",
- "else:\n",
- " print(\"GPU not available.\")\n",
- " cpu_tensor = tensor"
- ],
- "execution_count": 0,
- "outputs": []
+ "### Explicit Device Placement\n",
+ "\n",
+ "The term \"placement\" in TensorFlow refers to how individual operations are assigned (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. For example:"
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "4E-2n7VbzY1n",
- "colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
- }
- }
+ },
+ "height": 53
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 1762,
+ "status": "ok",
+ "timestamp": 1526420547562,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "RjkNZTuauy-Q",
+ "outputId": "2e613293-ccac-4db2-b793-8ceb5b5adcfd"
},
- "cell_type": "code",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "On CPU:\n",
+ "10 loops, best of 3: 35.8 ms per loop\n"
+ ]
+ }
+ ],
"source": [
- "# Time a CPU-based matrix multiplication\n",
+ "def time_matmul(x):\n",
+ " %timeit tf.matmul(x, x)\n",
"\n",
- "print(\"Time to conduct matmul on CPU:\")\n",
- "%time tf.matmul(cpu_tensor, cpu_tensor)"
- ],
- "execution_count": 0,
- "outputs": []
+ "# Force execution on CPU\n",
+ "print(\"On CPU:\")\n",
+ "with tf.device(\"CPU:0\"):\n",
+ " x = tf.random_uniform([1000, 1000])\n",
+ " assert x.device.endswith(\"CPU:0\")\n",
+ " time_matmul(x)\n",
+ "\n",
+ "# Force execution on GPU #0 if available\n",
+ "if tf.test.is_gpu_available():\n",
+ " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n",
+ " x = tf.random_uniform([1000, 1000])\n",
+ " assert x.device.endswith(\"GPU:0\")\n",
+ " time_matmul(x)"
+ ]
},
{
+ "cell_type": "markdown",
"metadata": {
- "id": "vbSFW-T5zhZF",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "colab_type": "text",
+ "id": "YEOJTNiOvnpQ"
},
- "cell_type": "code",
"source": [
- "# Time GPU-based matrix multiplications.\n",
+ "## Next Steps\n",
"\n",
- "if tf.test.is_gpu_available():\n",
- " # First use of the GPU will be slow:\n",
- " print(\"Time to conduct first matmul on GPU:\")\n",
- " %time tf.matmul(gpu_tensor, gpu_tensor)\n",
- " print()\n",
- "\n",
- " # Subsequent uses are much faster:\n",
- " print(\"Time to conduct second matmul on GPU:\")\n",
- " %time tf.matmul(gpu_tensor, gpu_tensor)"
- ],
- "execution_count": 0,
- "outputs": []
+ "In this tutorial we covered the most fundamental concepts in TensorFlow - `Tensor`s, operations, and devices.\n",
+ "In [the next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/2_gradients.ipynb) we will cover automatic differentiation - a building block required for training many machine learning models like neural networks."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "default_view": {},
+ "name": "TensorFlow: An introduction",
+ "provenance": [],
+ "version": "0.3.2",
+ "views": {}
}
- ]
-} \ No newline at end of file
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb
index 1e65b27bc8..9c1af9c208 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb
@@ -7,12 +7,9 @@
"id": "vDJ4XzMqodTy"
},
"source": [
- "# Eager Execution: Working with Gradients\n",
+ "# Automatic Differentiation\n",
"\n",
- "This notebook demonstrates:\n",
- "\n",
- "* How to get gradients using TensorFlow's eager execution capabilities\n",
- "* How to apply the gradients so you can update your variables"
+ "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models."
]
},
{
@@ -22,7 +19,7 @@
"id": "GQJysDM__Qb0"
},
"source": [
- "# Setup: Import eager and enable eager execution.\n"
+ "## Setup\n"
]
},
{
@@ -40,12 +37,10 @@
},
"outputs": [],
"source": [
- "# Import TensorFlow.\n",
"import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
"\n",
- "\n",
- "# Enable eager execution.\n",
- "tf.enable_eager_execution()"
+ "tfe = tf.contrib.eager # Shorthand for some symbols"
]
},
{
@@ -55,28 +50,15 @@
"id": "1CLWJl0QliB0"
},
"source": [
- "# Fitting a Simple Linear Model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "-39gouo7mtgu"
- },
- "source": [
- "## Step 1: Synthesize some data\n",
- "\n",
- "To demonstrate fitting a model with TensorFlow's eager execution, we'll fit a linear model to some synthesized data (which includes some noise).\n",
+ "## Derivatives of a function\n",
"\n",
- "In the code, we use the variable names `w` and `b` to represent the single weight and bias we'll use to fit our model."
+ "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: "
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
@@ -84,105 +66,53 @@
}
},
"colab_type": "code",
- "id": "rQsdCg9PfIL-"
+ "id": "9FViq92UX7P8"
},
"outputs": [],
"source": [
- "# The constants we'll try to fit our variables to:\n",
- "true_w = 3\n",
- "true_b = 2\n",
- "\n",
- "NUM_EXAMPLES = 1000\n",
+ "from math import pi\n",
"\n",
- "# Our inputs:\n",
- "inputs = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n",
+ "def f(x):\n",
+ " return tf.square(tf.sin(x))\n",
"\n",
- "# Our labels, with noise:\n",
- "noise = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n",
- "labels = inputs * true_w + true_b + noise"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "cellView": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "base_uri": "https://localhost:8080/",
- "height": 347
- },
- "colab_type": "code",
- "executionInfo": {
- "elapsed": 374,
- "status": "ok",
- "timestamp": 1525154227149,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 420
- },
- "id": "O4lsC4ckAcar",
- "outputId": "f8becb3f-498b-4cb7-9ef3-608a68cb65d0"
- },
- "outputs": [
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAecAAAFKCAYAAAAnj5dkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xt8VPWdP/7X3M5MkpkkM8mEAAER\nQoICgUBALkUEQ7FucekDEeWL3VZXu121dler39pu1Vbb77b+2m1/3277qNXa2kUptGttt/tDEWqp\nyDWBiC6ES8slXDJJJpfJ3C+/P8JM5nLOmTOTmWQm83r+RebMnJyTAO/z+Xzen/dbFQqFQiAiIqKc\noR7rCyAiIqJYDM5EREQ5hsGZiIgoxzA4ExER5RgGZyIiohzD4ExERJRjtGN9AWE220DWzm02F8Nu\nd2bt/LmukO+/kO8d4P0X8v0X8r0D+XH/VqtJ8lhBjJy1Ws1YX8KYKuT7L+R7B3j/hXz/hXzvQP7f\nf0EEZyIionzC4ExERJRjGJyJiIhyDIMzERFRjmFwJiIiyjEMzkRERDmGwZmIiCjHMDgTERHlGAZn\nIiKiJDy+ADrtTnh8gVH5fjlTvpOIiCjXBIJBbNt9Gq3tNvT0e2Ap1aOxzopNq2uhUWdvfMvgTERE\nJGHb7tPYdfhi5Ovufk/k683NdVn7vpzWJiIiEuHxBdDabhM91treldUpbgZnIiIiEX0OD3r6PaLH\n7ANu9DnEj2UCgzMREZGIMqMellK96DGzyYAyo/ixTGBwJiIiEqHXadBYZxU91lhXCb0ue20pmRBG\nREQkYdPqWgBDa8z2ATfMJgMa6yojr2cLgzMREZEEjVqNzc112LByBvocHpQZ9VkdMYcxOBMRESWh\n12lQZS4ete/HNWciIsqa0a6sNV5w5ExERBk3VpW1xgsGZyIiyrixqqw1XvDxhYiIMmosK2uNFwzO\nRESUUWNZWWu8YHAmIqKMGsvKWuMFgzMREWXUWFbWGi+YEEZERBk3VpW1xgsGZyIiyrixqqw1XjA4\nExFR1ox2Za3xgmvORESUMawIlhmKRs7t7e34x3/8R3zmM5/Bli1bcPnyZTzxxBMIBAKwWq34zne+\nA0EQYj7zzW9+E8eOHYNKpcJTTz2FhoaGrNwAERGNPVYEy6ykPzGn04lvfOMbWLp0aeS1H/zgB9i8\neTO2bt2K6667Djt27Ij5zMGDB3Hu3Dls27YNzz//PJ5//vnMXzkREeWMcEWw7n4PQhiuCLZt9+mx\nvrS8lDQ4C4KAF198EVVVVZHXDhw4gFtvvRUAsGrVKrz//vsxn3n//ffR3NwMAJgxYwb6+vrgcDgy\ned1ERJQjlFQE43R3apJOa2u1Wmi1sW9zuVyRaeyKigrYbLG/lK6uLsyePTvytcVigc1mg9FozMQ1\nExFRCjy+QFYzppNVBHt150mcPG/ndHcKRpytHQqFMvIes7kYWm320uytVlPWzp0PCvn+C/neAd5/\nId+/xVKCl3/3IfYfvwxbrwvW8iIsmTMR962bDY0mc4HRVFYEq7kInXZXwjG9oMW+41ciX4enu4uL\nBDywfm7GrkFMPv/u0wrOxcXFcLvdMBgMuHr1asyUNwBUVVWhq6sr8nVnZyesVvFqMWF2uzOdS1HE\najXBZhvI2vlzXSHffyHfO8D7L+T7t1pN+L+/ao3pDNVpd+HNvWfhdHkz3hmqYUZFzPcKC4WCou9/\n79glfGLxlKztfc6H373cw0Naj07Lli3Dzp07AQBvvfUWVqxYEXN8+fLlkeMffvghqqqqOKVNRDSK\n3F7/qHaG2rS6Fs1NNagoNUCtAipKDVg+pxpur3hwZgMMeUlHzsePH8e//uu/oqOjA1qtFjt37sQL\nL7yA//2//ze2bduGSZMmYf369QCAf/qnf8K3vvUtLFiwALNnz8bdd98NlUqFp59+Ous3QkREw+z9\nyTtDZaI4SPR6dnxFMAA4cd6ObpHrYAMMeUmD85w5c/Dqq68mvP6zn/0s4bXvfe97kT8//vjjI7w0\nIiJKl7l0qDNUssCYLFlM6rjcvubooN9YZxWd7mYDDHks30lENA54fAHYel1AKASruRhWQSsbGLUa\nFbbuapcsGpKsqEh4X3NYONELQMx6tlgDjIYZFqxqnAyPL8AALYHBmYgojwWCQbz+zim898EVuL1D\n68gGQY3mxdfhzlumAxDvDJUsuMod37Byhsx6tg0bVs6IBN3oBhg9/W7sOnIRbae78MfWS9xWJYPB\nmYgoj23bfRrvHOmIec3tDeL3f/4L3G6faGeoZEVD1i2bJnv85nmTJNezu/s9eHXnSXz29lkxAVev\n02BPawf2tHTEvFdstE1sfEFElLfkgiwAtJy0RaaOq8zFkdFssqIhFzsdsscRCsFSKp3Mte/4lYSy\nnUqqiNEwBmciojE0krKWckEWAOwDHtHtSmVGvWRwNZsMqKkySh4XdBpYyorQWCdfuyI+4CZ7IOC2\nqlic1iYiGgPpdHGKz5wOB1mxjGwAMJv0otuV9DqNbLKYqViQPO72BvDG3rPYtLoWLrcf70VV/4oW\nv11L7lq5rSoRgzMR0RhQmu0MyAdyqSAKAAvqrZLZ0GJZ1OFkMQBYv+J6/LntciTJLFprexc2rJyB\nLWvr8T/netAz4E14T3zATfZAwKztWAzORESjLNn6a3S2MyAfyDetrkUoFIrL1tagefFU/O2y6ySv\nITqLWmwfs8Ppg0ckMAOxo+IF9VWKA26yBwIaxuBMRDTKlKy/hqeDlQTy/7WmHnfeUhuzz7lmUrmi\n2tLhZLF4SqehUwm4yR4IaBiDMxHRKEtl/VVpINfrNKixKu9hkKwymNJp6HQCrtQDAQ1jcCYiGmWp\nrL9mOpHK6fFh69uncOJcD+wDXlhK9WiorUTzwhpYSg0x3zuVUTEDbmYxOBMRjQGlgS9TiVThpLL4\nJK/ufg/2tAwVB6mIyxgXGxUDQHefO/JnTk9nB4MzEdEYSGU6eP2K6+F0+3HinB29Do9oIE82TR2f\nVCZGKmNcr9OgoswQyRjv7vfAIKgBqODxBliGMwsYnImIxlB4v7LSzk9LZ1fjnjV1KNZrJd/TWGfF\nw3c1Rs6TrJJYPCUZ49F9mlmGM/MYnImIRlH0CFerUaXc+em941fg8vrxd7fNgqlYkNxmVVwkYP3y\naQCSVxKLl0rGeDSxoE7pYXAmIsqAZNPKYiPcYoMOFzodkfco7fzU0t6F1vY/Y7K1BE63T/Q9+49f\nxicWT1FUSSye2aSH1xeI1OVWGtzjgzqlj8GZiGgElJbhFBvhSgXLZJ2fACAE4KJtUPJ4V68LZzv6\nMH1ymWxSmZhBtw9Pv3woci/rV0xXFNxZhjNzGJyJiEZASRnOVNd8ozs/KR3tJlAB33n9aCQDO7G3\nsx71U83QaVU4ftYO+4Abgk4DtzcQWU+OvhclwZ1lODOHwZmIKE1Ky3CmuuZbbtQDKhUaaitj+h+n\nIngtXyv+YUEsO9zjC8DW68K//eqoZC3tZ+9fhEAgiNZTXeh1eGEQhj7r9QVYhjMLGJyJiNKktHpX\nqmu+To8fT790EGaTgImWYlzucY74WsMPC2L0Og0ErRp2kQYWwNC9bH37FE6et6PP4YXZqMf8ukps\nWDkdDqeP+5yzgMGZiChNSqt3KV3zFbQqeP2hyOh1qNuTF3qdGh5fUPazydgH3Hh150mcPG8XXRuX\nuxdBp8G+qNaQdsdQ4RKNWsWtU1nC3eJERCny+ALotA+NZhvrrKLviV9/3bS6FsvmVMue1+sPib6u\nUqV5oVEEnRr7jl9Bd78HIQxPd7/+zikAww8Q4sSvq7W9Cx6feOcqGhmOnImIFBLLzJ4/sxKrF07G\nsVPdsmU4NWo17l1bj5Pn7Sknebm9QSyfU40j7TbRNWElfH7xkfd7H1zBnbfUQq/TYNPq2si6cp/D\nC0upAbOmluO9qFFzNG6dyh4GZyIihcQys9850oHmpho898BNSctwprqlKUytGirh+T/netIKztWW\nIlzpcYkec3uHksEmVhRj2+7TaDvTjT6HF+VGPRpqK7Bh5QyckHig4Nap7OG0NhGNC+Gp5mxNsybL\nzAYQad0oJRAMIhQKRTKdlQqGgE67SzJhKxmXxy//hlAo8uARnvYOryu/sfes4ql7yhyOnIkor8kV\nAckkpZnZcrbtPo13jqS+Ncpi0qOmypj2vue+QR/0WjU8IlPbBmGogpjcg8ez9y+K/DlZ60jKDAZn\nIsprckVAHr1nYca+z0j7KqdaiCRaSZEOGo0K9VPNMVnTSpUbBcybWYl3Wy8lHFs2txouj1/2wcPh\n9CnuoEWZweBMRHkr2VSz25tkOjcFep0GDTMqsEckwCmZ3k21EEm0C50OPP7DffB4AzAIGoRCoZS2\nVjXOrMTmNXXQadRoOWlDz4AHZSU6LKiz4p5bZ8IfkK5GFr8ljMlfoyPt4Lx9+3a8+eabka+PHz+O\n1tbWyNezZ8/GggULIl+/8sor0Gj4pEVEmZNsqtne78nICCQ8dd52phvAUIJWMDQ03bygXnoKPboZ\nRqqFSOKFE8FSTQibaCnG5jV10KjVQ9nYwRCOtneh1+FB25luaDSnsWl1rWSiGteVx0baf283btyI\njRs3AgAOHjyI//7v/445bjQa8eqrr47s6oiIZCSbajaX6jHQJ56lrEQ4uO48dCGmjGbw2rbfeTMr\nRYtwSK2Dz59Zmdaas5jwA4Icg6DGV/6uKdKAY9vu0zH3Eb0EEH7A4LpybsjItPYPf/hDvPDCC5k4\nFRGRYnJbkxrrKmEQtBgQ+Vwq7R27+z1QSxQBaTvdDc+qQMI5tr7dHjP9HQ6CtyyYhBpriWw3KRWk\nSn7EShaYAeBjDZNQrB/6b37A6cXh/+kUfV+4tCfXlXPHiINzW1sbJk6cCKs1NtXe6/XiscceQ0dH\nB9auXYvPfvazsucxm4uh1WbvL4LVasraufNBId9/Id87MP7v/+G7GlFcJGD/8cuw2V0wl+qxZM5E\nPLh+LoDY+w8Egnj5dx8OvbfXBWt5EZbMmYj71s2GRjO8s/TFNz6ICfhSgdA+4IZG0MFaWRI5/0/e\n+ADvHktclwaAAx9ehcsjPy2tJDADgLXcgAX1VXj70PlIk4toRXotHlg/F3pBi5++eRxvHzwHj1d8\nnTr+PmoUXkOuy+e/+yMOzjt27MCnPvWphNefeOIJ3HHHHVCpVNiyZQuampowd+5cyfPY7SMv7C7F\najXBZhN7fi4MhXz/hXzvQOHc/7qlUzEw6MFRXxfs/R4cOH4ZXq8fD9/ViJ6e4VHq1l3tMUG30+7C\nm3vPwunyxrR3fO+Ysqlns8mAgNcX+RnHnz9essAMDK1jz5tZibbT3TFtHOPNq63E8jnVeOvAedHz\neLx+/OWCHbuOXExa9CT+PsaDfPi7L/fwMOLgfODAAXz1q19NeP2ee+6J/HnJkiVob2+XDc5EROmS\nWkstLhKwfvk0ANlp7xidLDWSrVLRFtRbsbm5Dp5VQ1PvxmIBb+w9G7MWPG9mBUKhEP7tV0clR9qV\n5UUo0mvRclJ8KlvqPig3jKhC2NWrV1FSUgJBEGJeP3v2LB577DGEQiH4/X60tLRg5syZI7pQIiIx\nckFx//HLkYphSoqIAMNJZmLUqqEmFBWlBjQ31WDT6tpIZTKb3Zk0qMtVBrOY9JFzAsPblor1Wmxu\nrsNzD9yEbz64BF/7TBM8ngDeOdJxrWuVuCVzJg7tX05SVWzZnGomfeWgEY2cbTYbLBZL5Ouf/OQn\nWLRoERobG1FdXY0777wTarUaq1evRkNDw4gvlojGv2TJWvHkgq7N7sLZjj5Mn1yWkfaOK+dPwtrF\nU1Fm1EOrUSVkZAs6FTw+8bGsXqfGktlV+GPr5YRjy+dUY8vaetn71WpU2HXkIlpOdsoG3IprmeH3\nrZuNy1f7YTEJku+3lOpx79r6SDY35Y4RBec5c+bgpz/9aeTrBx98MPLnL33pSyM5NREVGLkynHLB\nQy7oqtTAC68fjZxr3sxK7BbZyiTW3hEQ31YUvpb49eVk+5c9viBuXVgDrUYje14p8ZXQxKgAPHpn\nA2qqTNBo1NDrNFhQXyX5uQV1Vk5n5yhWCCOinCBXhlNsL3GY3Eg3nMUcPtetCyejualGci9v9Kg9\nflsRAHT3uSN/Tmd9efeRDty7dlbS7UrxswceX0DR2rGl1ABrXAWvTatrEQyFsO+DK5HEMoOgwfK5\nnM7OZQzORDTmlCZrSYke6fb0u6GSKNBx9FQ3nnvgpoTgGAgGsXVXu+iovaLMkDCir59qTqsUZ9uZ\nHnh8AckymFKzB6saJyddOwbEE7s0ajW2rKnHxltqYbM7AZUK1vIijphzHIMzEY25kXZ80qjVkZHu\n2Y4+vPD6UdH39Qy4I2vQ0eeTG7UHAsGEgiL7jl+BRF0SWcnuRap4idcfkK0IJuhUWNEwSXYkrNdp\nUFOVv/t+Cw2DMxGNuZF2fArT6zSYPrlMeg0awHdePxpJmtq0uhb+QEhy1L637RK8EoU7lBYLiSZ1\nL0Mj91N496h48ZK2092yFcG8vhBUKhUTu8YR/iaJaMyF143FpLoHV+5c4QAXHpFu231adtTu8QbT\nCsJSpO4lvE9bKgD3ObwoNwriB69pbe+KbBsDALfXj067M+Y1yh8cORNRTshk44XwZ9rOdMPW64IK\n4lPCre1dWLds2oi6RSlRbhTQNKtK9F6UFC+xlBrQUFsRU2glXnjKPLxG3namGza7S3HWO+UWBmci\nygnR68YjbbwQPtfnNhTh4LEOfEdiDdo+4IbL45fM9s6EcqOAZ+9bDFOx+MhXSUWy6IeUd1vFR9jh\nKfN0s94pt/AxiohySjiTOZXAHK7SJTaFayrWoUKi4lc4oG1aXYtbF05GNgaWbq8fb+w9i8vdg6LX\nl6wi2c3zqrGqcTL8gRDuWlWLxTdMEH1vY10lAOktXvHT3pTbOHImorwltfXozlumY8cfz0amdvWC\neNSNXgMOBkOi3Z1SMVTeU4VA1NDW7R3K9t7TeikmES08xSy3T3tSZQk+/Isde49dgV7QAAjB7Q3C\nIKgBqOD1BWKm/7v73CPKeqfcweBMRHlLagr35PleXOh0RF53X8u41qiBwLUAbBA0CIVCCASDQxnb\np7pGdC2L6q24Z00dnvv5Ick9yVJTzGLr7cUGbdw9RCd7Dd3EsjnVuDeq7Gemst5p7DE4E1Fekkuk\n6rA5RF8PRI2M3d6h5hHBENBUZ0WvI3mRDznGEgFeXwB2BcVC4gurxK+3F+m1+Porh5Ke5+T53piv\n5Ubh7DyVXxiciSjjUm1ekY6efrdkhrXcnuB477Z2yGZBp3Ieh9MLs0yjiTD7gBu2XhcErTrmZxRe\nb+9U0OEKGPoZxE9VR2eqd/W6RpT1TmOHwZmIMibd5hXp2HVEOrtarppWvFQCebLzHDphg0bBbQo6\nDb63rRV2hw8Wk4AF9VUxPyO56eloekGTMFUdnal+5q/dWX1AouxhtjYRZUx4Dbi734MQYot9ZJLH\nF0Dbaek14urKsUt6CihIKnN7A7A7fACAngEvdh2+iNfeORXJOgcgWUhFKYOgTTnrnXIHR85ElBEj\nbV6RCrkpbQCwlhlwyebMyPfKJLUKQAgQi99/bO3A0XYb7ANeWEr1mDezErcunIyWk12wO8Tv1Xtt\n+YAZ2OMPR85ElBFKmldkytuHz0seU6uAY6d7Mva9MikoEZiBofaWPQPeyIzD7iMdUKlUeOa+RZKl\nO5mBPX4xOBNRRsgV05ALIh5fABc7B3DR5lBUJMPjC2D/h9K9jTO1hpwLWtu7IOg0aJpVJXqcGdjj\nF6e1iQjAyDOsU93GEwgG8do7p7Dvg8uRfbsGQYPlc6tx960zJRPIbL2umD2/41l4xiGTdccpPzA4\nExW4TGZYpxJEtu0+jd1HYrcwhfceq1QqbG6uE39gCOXe0DiV7PBUhGccMll3nPIDgzNRgctkowSl\nQcTjC6DlpPTUdGu7DYFAEG1nuhMeGKzmYhgEdWS0nQuyNZUeP+MQ3gdN4x/XnIkKWLIMa6WNEuIb\nTyRrXtHn8MgW6uju92BP6yXRLVl6nQY33Sje/GG8ELRqNDfVcNq6gHHkTFTAlGRYy43U0p0SL9Jr\nIWhU8AbEh5xS08QtJ20IBEM4fjY3s7GTUauGZuXNJj2cHr/o2rleq8a3/mEpypmFXdAYnIkKWCqN\nEsTWf1OdEo8O5lKBGZCeJu4Z8GSk1OZYWXzDBKxfcT3KjHr8+t0zoslzK+ZPYmAmBmeiQqYkw1pq\ndLx+xfUpFx2JD+ZSykp06Bv0JbyercSr0WAQNNiyth7F+qH/dpmBTXIYnIkKXLIgITU6drr9KU2J\ny61vR6soNWDuDDP+2Ho54Vi+BmYA+FjDxEhgBpQnz1FhYnAmKnByQUIuoJ44Z1c0JR6eDvf6Aoo6\nLTXWVSKYg9ul0lVRKj8iZgY2iWFwJiIA4kFCLmGs1+HB4hsmoPujqwnH6qeWIxAMYeuu9pjpcL3M\nFihTkRY3Xl+B25dch+d/cXjkN5QDyo0CvvaZJpiKxctvEklJOzgfOHAAjz76KGbOnAkAqKurw7/8\ny79Eju/btw/f/e53odFocPPNN+Ohhx4a+dUS0aiSSxgTdBq0X+wFMLwWrNepgJAK7x+/gpb2zphA\nnKz94YDLjwMfXR1KFvPnzh7mkegf9MLl8TM4U8pGNHJevHgxfvCDH4gee+655/DSSy9hwoQJ2LJl\nC9auXYvaWiY6EOWD6MxsqYQxtzcQ2QoUXgv2+EIAQteOiwdYg6BGKAR4fOLHMxGY9Tq15PnV6qEm\nE+kyCBrF5UPZmILSlZVp7QsXLqCsrAwTJ04EAKxcuRLvv/8+gzNRjhPLzJ4/sxKrF07GsVPdsA+4\nUW6U3qOrRCYre5mNAubXWdF2ujsmmW353Go8+zPxqfFQEBC0Knj9qa1rq1XAyvmToFar8M4RZdu5\nGmZYmORFaRlRcD59+jT+4R/+AX19fXj44YexfPlyAIDNZoPFYom8z2Kx4MKFCyO7UiLKOrHM7HeO\ndKC5qQbPPXDTUGKXP4inXzo4hlc5zHXtAeHhDXOgUathLS+CVqPC1rfbJbddhQD4ZfZYSwmGgOam\nKagyF0GlUqG1vQs9A27ZUt/NTVNS/j5EwAiC87Rp0/Dwww/jE5/4BC5cuIBPf/rTeOuttyAI6a2t\nmM3F0Gqz94RptZqydu58UMj3X8j3Dii/f6fLiz+3JW5fAoDWU134zLo5qJlUDrfXD6u5CJ12VyYv\nMy1ubwB7Wjqwp6UDVnMR5s6ohKBTY0/rJdnPpbsl670Pr+LzG+bh0XsWwu3140r3IL7+0/2w9boT\n3qtWD73/wfVzodGMTaVk/t3P3/tPOzhPmDABt99+OwBg6tSpqKysxNWrVzFlyhRUVVWhq6sr8t6r\nV6+iqkq8H2mY3e5M91KSslpNsNkGsnb+XFfI91/I9w4ov/9AMIiv/fSg5FR1d58bD39nN5pmVWHT\n6lo0zKhQVExkNNnsLuw+nN0Zuv0fXMa6pddFpqpLtGrMq60U/VkEg8Af9v0VXq8/5QYimcC/+7l/\n/3IPD2k/zr355pt46aWXAAxNY3d3d2PChKFi9DU1NXA4HLh48SL8fj/27NkTmfImotwQ3axi665T\nuNwj/4Dc6/BGmk9sWl2L5qYaVJQaoFYNJUllmirjZxy5ngEP+hyxWeebVtdiVeMkqCUuOJUGIkRh\naY+cV69ejccffxzvvPMOfD4fnnnmGfz+97+HyWTCmjVr8Mwzz+Cxxx4DANx+++24/vrrM3bRRJS+\n+KQvs0nAoMuv+PPh0pzRhUuMxTq8sfcvOHyiE70O6W5TqcjFMiRq1VDTjmgatRprF0/FHyWm0pU0\nECGKl3ZwNhqN+PGPfyx5fNGiRdi2bVu6pyeiLIlP+pJr3SgmOthEFy7Z3FyHdcum4emXD2YsQOea\nYAii+5ZTaSBCpAT7ORMVEKX1reWIdasKT4+bigUUG8Zv4cGKUr1ooA03EBETbiBClIrx+6+IiBLI\nleNUSq5bVZFei0td2UvuHGuNdVbJQMsuU5RJDM5EBaTMqIfZJIhOZet1ahiLdOgZ8KC8RI95MysQ\nCoVw7HQ3+hxeWEqTd6sCRhb4c9nK+ZNkAy27TFEmMTgTFRC9ToOSIvHgXGUuxlP3LoxJ8Gptt6HP\n4UW5UY+G2gpsWl0LjVqdkenxfPOJm6ZCo06+EsguU5QJDM5EBcTjC8Dp9okec7p9sNmdsJqL8et3\nz8SMiu0OD/a0dECjVmFzc11GpsfzidmoY1IXjSoGZ6ICIhdUu/s9+NrLh2AxCXB6xPfltrZ3Yf2K\n6fjDgXNQqSBbunI8MRbLT1FHNwrhVDZlAoMzUQGR2/ITJre1yj7gxvM/P5y0YMl4M+jyweMLJARe\nsaS4xjprZPqfKF3820NUQOS2/Cih1agKLjADQK8jsTIYMJwU193vQQhDsw/hKmpEI8HgTDTORO87\nFhNdelOVYo1Mf7odI/KcWCERuaQ4luykkeK0NlEOS2UtU2yKdfm8yVi3NDbLOLzlZ/2K6/HLnSdx\n4KNOxaUyg5lrxZxXxAqJyK3fs2QnjRSDM1EOSmctU2zf8Zt7z8Lp8op2RXpj71+w/6POrN1DPlOr\nhmp7W2QKibBkJ2UTgzNRDhILtOGvxQJtsinWDStnxIz8CnGfcipWzp+EtYunxsxYDDi9uNjpQE2V\nEaZiIbJ+L9YukiU7aaQYnIlyTKqBFkg+xWrrdUHQqiPBps/hkc3YLiRTqoxwuv0JJTfDMxRevx/P\n/6IFHTYHgqGhUfVkqxFf+fQCluykrGFwJsox6axlyk2xCjoN/u1XR2Ef8Eamx29fMhVq1VCXpZEQ\ntGp4/bm/EF1jLcGDd8zGntYOtJ3uTgik/kBIcm3/+V+04EKnI/J1MARc6HTg+V+04Nn7FrNkJ2UF\ngzNRjklnLVNuitXtDcDtHcocDk+PO93+EQdmAPjnTQ048FEn3j16KSPny4abGybi3tvqoVGrce/H\n6+FZlZhkp1FDNHlrwOlFh82R8DoAdNgcGHB6I1PcTP6iTOJWKqIck077wUAwiFAoBIMwfMwgqGEQ\nxP+Jnzhnh8UkiB5LxSt/OIlFCLMMAAAgAElEQVSb508as0phgjb5XjCNJvY94UCqZIR7sdMh+dAR\nDA0dJ8oGBmeiHBS9F1mtAipKDWhuqpFcy9y2+zTeOdIRGSEDgNsbhNsrPuXc6/DghussI77OK3YX\n/s8vWyDoUtwwnQHlJTo01FZCp5X/b2xP66W0i4LUVBmhlrg1tWroOFE2cFqbKAeF9yKvWzYtJkM4\nnscXgM3uTDnzWtBpcM+aOpzvdMSsp6bD4xubNefeQR8On1B231KJdMmYigVMthpFf0aTreK/E6JM\nYHAmyiHhoiPRLRvF9jlH74NOJ+s6FArhavcgBl3SdbTHk5EUBfnKpxdIZmsTZQuDM1EOiC86ohc0\nMVPU8fuc4/dBp8rjC+Ibvzgy4uvOFyMpCiJotXj2vsUJ+5yJsolrzkQ5YOvb7TENFKIDc7TW9i4M\nOL0sIJKiTBQFMRULuGGahYGZRgVHzkRjKBAMYuuuU3j36CVF7+/pd+Nip0NyH3Suq7YUodPuyvq2\nK/W1XtNWcxEaZlSwKAjlHQZnojHi8QXwy50n8d7xK4o/oxc0qKkyJu3JnItunj8RaxdNRZFei1f/\nvxM4cd4OlzeYkWIo8UIAHr97PhbPm4yBPldmT040ChiciUZZeH35yImrsDt8KX46hDf2nsWgO9XP\njb332i7jT0cvJ7yejVF0eYke0yeXwSBoMZD50xNlHdeciUZZOJkr9cA8tHd5T+ulhP3LGjWwsnEi\nKkpztxNSIMmOK71O+X9HS26sgtmokzw+n40nKM8xOBMl4fEF0Gl3wuMLiH6d6rmykcwVCAJqlVqy\nslg+ULJf2iBo0NxUg/s/eSMWzpog+p4pVUZsbp6Z6csjGlWc1iaSEL+9yWwSUFIkwOn2Ke6xHC+b\n3aBaT9owt9aSN80o4pmKdBB0atmfT4lBiw0rZ0CjVsd0hOrpd6PMKKBxZiU2r6lT/PsgylUMzkQS\n4vcS9wx40TMwXLQjWY/leIFgEDsPXchKAhQA9A56sfeY8uSyXNNYXwlBq5Hdv20f8ESKiYSrqLEj\nFI1HIwrO3/72t3HkyBH4/X587nOfw8c//vHIsdWrV6O6uhoazdA/lhdeeAETJohPQxHlmlSmn5WW\nhty2+zT2tHSM6LoErQpef462f7rm5saJOHuxHxdtg4o/o9WocO/H6wEAgWAI77Z2iD7AiBUTYUco\nGo/SDs779+/HqVOnsG3bNtjtdnzqU5+KCc4A8OKLL6KkpGTEF0k02uR6KsdTUhoyU2vNC+qtOH6m\nBw63f8TnyhatSoWnP7sIW99uR+upLvQ5vBB0atk15dJiHfyBELQaFTRqFXRa8fdnopgIUT5IOzgv\nWrQIDQ0NAIDS0lK4XC4EAoHISJkon8n1VI6npDSkXLBXqQBTkYB+p3yda4OggVarzunADADvHb+C\njatm4t61s3DX6qFa4V5fAF97+ZDkZ+wOL/ocHuw6clF0WnsoG30yi4lQwUg7OGs0GhQXD40UduzY\ngZtvvjkhMD/99NPo6OjAwoUL8dhjj0Glkm4rZzYXQ6vNXmC3Wk1ZO3c+KOT7T/fel8+bjDf3nlXw\nvkmomVQu+x5TWRGs5qHqWAnXV16EuuvK8WeRPcDRqiuKse+D3F9T9niD8KtUqCwrwmC3EyUmA2pM\nBljLDbD1ukU/Yy0vQs2kcrT96pjo8UAQMOh1qJ5QlvL18O9+4crn+x9xQtiuXbuwY8cOvPzyyzGv\nf+ELX8CKFStQVlaGhx56CDt37sRtt90meR673TnSS5FktZpgsxVuKYJCvv+R3Pu6pVPhdHnR2t4F\n+4Ab5UY9Sop0cLp9sA94YDYZ0FhXiXVLp8JmG4h0lJJKTGqYUSE6Kuwf9CQNzBMtxfjr5fz5Hf7i\nDx/hg9Ndkf3YBkGDynKD5PsbZlTg4qVe0YeXsPfbLmPd0utSmtbm3/3CvHcgP+5f7uFhRMF57969\n+PGPf4yf/vSnMJliv8n69esjf7755pvR3t4uG5yJco1GrcaGlTNwc8NEQKWCtbwIep0mIQgP1cdu\nl2zvGHbnLdNx8nxvpPVgWHxBkXhmowCvP/U91WPp0EedMV+7vQFc7BxEtaUI9gFPZD3ZIGiwfG41\nNq2uhT8QQrlRQK9DfHq/d9CTdttHonyTdnAeGBjAt7/9bbzyyisoLy9POPbFL34RP/rRjyAIAg4d\nOoS1a9eO+GKJRkv8HufogBufHRy/5Sp+i1U4mO88eB4XOh0pX8sN0yx4P4X627nsSo8LZpMejTPL\nsPam61BtKY6MhDVqoHFmJfa0ijcBsYyg7SNRvkk7OP/hD3+A3W7HF7/4xchrN910E+rr67FmzRrc\nfPPN2LRpE/R6PW688UaOmimvSAXcQDAU2fIDAANOLw6f6BQ7BVrbbQgEgmg7042efg9kUi4kLZ9T\njXvWzMTJ8/a8a3QhxT7gwf6POqFRq7FlbX3Msc1r6nC6o1/0IYaZ2lRIVKFQKCc2TWZzbSAf1h6y\nqZDvP5179/gC+OqL+0WDoVoFLLphAjavmYk39v4FR050ot+ZnSYUpmItHr2zAZOtppS7V+ULi0nA\ngvqqmCWAcBvNo+1d6B30wHJtbT+VSmxh/LtfmPcO5Mf9Z23NmSjfhaeci/RauDx+lBn1stuegiHg\nwEdXceCjq0nPrbrWUzhdA04/nvtFCwyCGotvqIJBUCddn843PQPehCprGrUad62qxarGyUAoBKu5\nmCNmKjgMzlSQwmvKLSc70TPgjZTUrCjVY870CpSVCOgdlN93nEym5qTc3iD+dOwKplQZ01qzzgfh\nKmtajUpyrZ/1sqmQMDhTQYpfUw5nT3f3e/DuUfGEpLE2MDg+1pzFhKusxRchSbV+OdF4wUdRKjjZ\natuYbb2D2VnbzgVmkwFFeq3k76W1vSutFp1E+YrBmcalcM9ltzex1GUqdbMzqdwoYGXjJOi16f2z\ny9VlV4Mw8v9GGusq4fL4JX8v4ZE1UaHgtDaNK/H7k63mIjTMqIhZs0ylbnaYkh7J1RVF6OxxiXZT\nUqmAL909H5ayIrSf68XlntQr4uXqwLGi1ICOLvn7MQgaeH0BmE16FBt0GHT50OsYrrIWLkIi9XtR\nUr+caDxhcKZxJX4tudPuSliz1Os0aKyzyvYNjmYx6THnegsOfHQVnmsBWqMeanPo8YWgAhAC4HL7\nJfs0h0LAsz8/BBVUst2Z8pHD5cOqBZPRdrob3f3uoZkBNeDzBSPBd/2K6XA4vZGqamKlTjVqSP5e\nuMeZCg2DM40bcmvJ8T2Xw92NWk7a0DPgiWRriykp0uFPbbG1rwNBYIKlCJe6nAh/rC/JmrDXFwKQ\nE2UFMqp/0Ie1i6bgrlW1kYALICH4FuuH/7uR6sEc/r2E65lHj6yJCgmDM40bcmvJ8T2XNWo1NjfX\nYcPKGZF9zg6XD7sOX0DbmZ5IYGiYYUHbmW7Rc15KMpVbKCylhkgQjg646dTAjv+9SDURIRrvGJxp\n3JBbS5Zas4wOKKZiAfeunRUz5WqzOyVrPcspLdZlrXJYrsnGlLPUyJqoUDBbm8aN8FqymPgAEs7m\njt+eEw7MxmIBv373DL63vS3l67CY9Nh068yUP5ePjEVaTjkTZQFHzjSuxK9ZVpYPZ2sD0t2m7rxl\nOnb88Sxa223o7vcoys6W4nB58dPffZSxe8plTrcfTrcfpmJhrC+FaFxhcKZxJX7Ncsa0Cgz0uSLH\npbpNnTzfG1MaM93APPTZ8Zf0JSUYAi52OnDDNMtYXwrRuMJpbRqXwmuWBmH4+VMum3u81qzONrUK\nqKkyjvVlEI07DM5UMMaqMthYaaqvhEFQlqhVbhSgVg01/qixlkCvG/6vwSBoYCwSn2SbbDVySpso\nCzitTQWjzKiH2SSgZ2Bk3abyhVarwavP3ob/OW2D1+fHd147KloAxSBo8Ox9iyMtM8NFQmx2J6BS\nwVpeBJUqhOd/0YIOmwPB0NCIebLViK98esEY3BnR+MfgTAVDr9Ng1nUW7Dt+ZawvZVS0n+8FANRY\njRhwehGSqrICQNBpYkbAep0GNVWxjeCfvW8xBpxeXOx0oKaKI2aibGJwpoKyec1MtLTb4PbmaKHq\nDOp1eNDV68L2XSfx52OX4Q2IB2ePNxBToEWOqVhg8hfRKOCaM41LUl2p9DoNKssMY3RVo8tsMuB3\ne89i95EO2ezzcIUvIsodHDmTLLEGBbkm+hq1GpVsV6ptu0/jom0w49dQXiKgpEiLrj53zjS2mD3d\njP0fJK9uJlbhKx9+70TjGYMziZIq1hHdenGsiV1jsUEXsy0quivVhpUzJLdSjVTvoBcurz9nArOx\nSIu2093odcgnvy2bUx1T4Ssffu9EhYDBmURJFesAhlsvjjWxa5Tq0dza3oWlN07I6laqXAnMAOBw\n+ZO/CYBOp4r5Oh9+70SFgI/ClCBZ68X4etRjQe4axXT3u/GDX7eNw4aNI/Nu62Vs230aQH783okK\nBYMzJVDSenGspVNQJFm/5ULVctIWWWPO9d87UaFgcKYE4daLYqRaL442uWuk1NgHPJHkr1z/vRMV\nCgZnSpBK68WxIneNwFDVK7UKqCiQbVMjYTbpI1nZuf57JyoUDM4katPqWjQ31aCi1HCt5rIBzU01\nOdW7d/2K6yVrR5cYtHjms4vw/X++BRUcYctaUG+NBN58+L0TFYK0s7W/+c1v4tixY1CpVHjqqafQ\n0NAQObZv3z5897vfhUajwc0334yHHnooIxdLoye+9WIu7nd1OH3wSFT6Cq+dlhn1mDXVjPcKpGQn\nANRPLcPJ831J32cQNFg2N3YrVT783okKQVrB+eDBgzh37hy2bduGM2fO4KmnnsK2bdsix5977jm8\n9NJLmDBhArZs2YK1a9eitpZP3vko3HpxrMgVwwivkYptnwoB+P6ONiw/1Y31N08vmOCsVgGf/cQN\n+MqL+xEQ2dmlF9T40j2NEDRqWM3FkoF3rH/vRIUureD8/vvvo7m5GQAwY8YM9PX1weFwwGg04sKF\nCygrK8PEiRMBACtXrsT777/P4EwpUVIMI7xGGr0vN1p3vwdv7j2LQx8WRmAGhjpFVZmLcUvjZLxz\npCPh+MfmTsT0iWVjcGVElIq0gnNXVxdmz54d+dpiscBms8FoNMJms8FiscQcu3DhQtJzms3F0Gqz\nN31mtZqSv2kcy7f7f/GND0SLYRQXCXhg/dzI6w/f1YjiIgHvf3AJtl636Lku9zizfr1jTa0GplWX\n4juPrIAgaPHIpgUoKdZj//HLsPW6YC0vwpI5E3HfutnQaAor1STf/u5nUiHfO5Df95+RCmGh0MhL\nO9jt2fsP1Go1wWYbyNr5c12u33/81LXT48NbB86Jvve9Y5fwicVTYqZjP7F4CqZXG/Fv29tG65Jz\nyv23z0JDbSVMxQL6+lyR19cvn4Z7b78BZ/7aHfnZ9vRkvq54Lsv1v/vZVMj3DuTH/cs9PKQVnKuq\nqtDV1RX5urOzE1arVfTY1atXUVVVlc63oXFOaura4fZJtnQMF8OoMhcnfL4QVZQa0HTDBMm1Y4Og\n5doxUR5Ka35r+fLl2LlzJwDgww8/RFVVFYxGIwCgpqYGDocDFy9ehN/vx549e7B8+fLMXTGNG+E6\nzt39HoQwPHXderJT8jPlJj28vgA8vkDC5wsR9x8TjU9pjZwXLFiA2bNn4+6774ZKpcLTTz+N3/zm\nNzCZTFizZg2eeeYZPPbYYwCA22+/Hddff31GL5ryn1wdZ49POtQ6nD48/fIhmE0CnJ7Cq/WsUSOS\nhW0Q1AiGQggEg+wYRTTOpL3m/Pjjj8d8PWvWrMifFy1aFLO1igqT3DaodGpjA4DXPxSZegbkWyGO\nFwZBA68vALPJgCK9JqYXtdsbxO4jHVCrVOwYRTTOsGUkZZySbVBye5QNgkZyzXm8U6uG9mhbTAY0\n1lVi/Yrp6HN4sPPQefz52GXRz7S2d2HDyhmc3iYaRxicKeOU9ASW26NcWW5AV687EqD1WjU8/uz3\nSv7nuxrQM+DBG386g95BZf2QM23l/ElYu3hqzGzDG3vP4k9HxQMzEJskR0TjA4MzZZTcWvKREzas\nWzYNpmIBACJlI1vbu2AfcMNsMqDYoMWFTkfsOf1BGAQ13N7EAG0QNCjSa2Af4TS3SgW88t8n0DPg\nRaa3AWvUgFajhscn/4BhEDTYcEstivXD/yyV9K1mxyii8YfBmTJKtieww4OnXz6IpllVkSnu6DrO\nRXotvv7KIYkzq0RftZYXYVJlMQ58JJ3hrUQoNLyOLVb2MhUq1dD5wl2xnti8AIJWjadfPoheh/RD\nhNcXgMPpjQnOStbmmbFNNP4wxZMyKlmf5V6HF7sOX8R/vH0y8lq4jrPL45cMRF5fANWWooTXL3Q6\ncPRUl8gnxk64Jk8wBNh63fiXn+7H7/b9FQvrpVtcAuIjYLmfp1oFrGqcxI5RROMQgzNlVLI+y2F/\nbL2MV986iUBweJhqLNZBL9ECUqdV42qPS/RYsuniseb2BrHr8EWEADQ31Ui2uWyYYUGfwwOPbzgZ\nTu7nubJxMu5dO4vbqIjGIU5rU0Z5fAGsapyMQDCElnYb+mSmcfe0dECjHt4G9Js/nZXM0s71AKzE\nsVPdeO6Bm7B+xfXY+vYpnDhnR6/Dg3KjHiVFOrSd6cYfWy8lZLeLrc031lVyxEw0jjE4U0ZEb5/q\n7vdA0Krg9Sev2xXeBgQA+z6QzkiWky9br6Kzqv/+kzdG9oHvPHQBe1qGO0jFZ7ezxzJR4eF8GGVE\ndClNAIoCMzAcsGx2p2g2thJL50xAjbUkrc+Opvg1Zb1OgzKjHm2nxdfMW9u7Eqa4q2R6MBPR+MHg\nTCOmZLuPlEjAUolnYyvh8wfh8ozNvuRUiGVVy2a3X3twIaLCw+BMafH4Aui0OyNTs+l2hQoHLGt5\nETRq8QCdLGy/13ZFtNKYEjoNUG4U0vqsHL1ODUE7fOUGQYPQtTrY0eSysbl/mahwcc2ZUiJWmrOh\nthJmk6Co3nV8ecropCadVoWAN3E6XC+oMb/Wiv0fXRU950g6UvkCkN17nK4qc3FMMRW3N4B3jnRA\nFVcHW6tRodigE3244P5losLF4EwpESvNuaelA1OqjIqC88caqnHTDdWoqTJGKoUBQ9O7UmvObm8Q\nK+dPwoGPruZ0a0iVauiho6G2AsdOiU/zRyfA9Tk82HnwfEJFNACYUmVkNjZRAWNwJsXk1padbh9W\nNU7C+x9eFc2cVquBSRUl+PAvduw9diVhu5CxWCebdf3j3x5XFJjD1blGm8WkxxfvmoeyEgEXOx0x\n2dfRevrd+OXOkzhx3o6efo/kUrvT7Yc/EMp4KVEiyg8MzqSYfPKSB2sXT8WGW2rx2tvtQ8FnwIOy\nEgGzppZDr9fi3dZLkffHbxd6Y+9fZLdD9Q36FF3jWARmAJhfV4k/HbsUme5Xq4YqhMXT6dR47/iV\nyNdS18tmFkSFjcGZFJNr8xhOXtLrNLj/2h5eW68LCIVQZtRL1sxube/CumXT0HJyZLWxR5v62gjd\nUjq0dh4KhWKm+6WCrldhMRUmgxEVNgZnUkyuzWN08lIgGMSv3z0TGUWWGQXJpCv7gBsXOx2K1qtz\nycrGyVg1fxKgUqGsRJB8+JAaQSfDZDCiwsbgTEmFt0uVGfWKSknGJ43JZUObTQZUmYvSDmIjoVYD\nQZmBrFoFlJYMPViEr89i0mN+XSVUAL6/oy3pw0cIQGmxgH6n/MNH/EicyWBEhY3BmSSJbZsKJ3FJ\nlZJMtSBJY10lXB7/qAdmYCgw11SVoMM2KDoNfcuCydh4S22knaXL40eZUY9fv3tG8cOHqViH/sHk\nswIrGydj7aIpLM1JRAAYnEmG2Lap6CQusWSlPodHtiCI2ahH36AHZpMBc6ab4XT78b3tbZm/eIVc\n7gD+n4eW4Ve7T+PDv3RjwBWA2ShgYVTP6fB9moqFlB8++gd9srMCFpMeC+qHs9aJiAAGZ5IgF4TC\ne3XjR3iBYBA7D12QDEYVpQZ87TNNcLh82HXkIt4/fjntetqZYh9ww+UJwFgsQNBpoXIFoJaoVAbI\nZ6wDQw8f9riSm1KBefmcamxZW8+RMhEl4KM6iVJa8zm6jOe23aexp6VDMhg11lXCVCxgT2sH9rR0\njHlgBoCyEj12HjofadoRwvAMwbbdp2PeGwgGsfPgecm9yRWlBjx17wLJcqBq1VAp0opSA5qbavCZ\n22cxMBORKI6cCUBs0le4W5LctiljsYCtu9qHM7JLBLi80s0nJltLsGl17YiaZGSD3eHBn4+Jt6qM\nnyHYtvs09kTt1Y7XWFeJQDAk2cM6BODxu+dj+uQyBmUiksXgXODkkr6ktk0V6TX4j7dO4v0Ph2td\n9yZJehp0+eDxBbD17VNpN6nIFqmRfnQhELmHCrVqKKFr0+pa+AMhyYcai8nAwExEijA457H40W46\n5JK+Nq2uxcnzvQm1ny/aBnHRNpjS9+lzeLH17VPYF1UdK9eVG/WRQiBy0/yhELB20RRo1Gpo1FC0\nF5yISA6Dcx6SG+2mkvGbLOlr3bJpcLqVlc1Mxlyqx4lzPSM6h0EYCmweXwAqZH9fdEmRLhJM5ab5\nLaWx1byU7AUnIpLD4JyHkm1xUipZ0tfFTkfafZrjOZw+eP3KE8AMggZeXwDma12emhfWwFJqiFz3\nHw6cw5+Oiq8VZ4rTPTQVr9dpFFdHAwCNWo3NzXWSe8GJiJJJKzj7/X585Stfwfnz5xEIBPDEE0+g\nqakp5j2zZ8/GggULIl+/8sor0Gj4H9RIpbPFSUqypK+aKqPkcTla9VBWcnQZaaWBWaUCbmmcjA0r\nZ8Dh9IoGtipzMdYumpr14Gwf8MQ0n0h1RKzXadi4gojSklZw/u1vf4uioiK89tprOHXqFL785S9j\nx44dMe8xGo149dVXM3KRNEzJFielASHZaFDQaVA/1ZzyOnEKA+QEEyxFuPfj9QCAYr30X09LqQEG\nQZ32dixBq4I/EJKdGo9vPpFsRJyJHAAiIiDN4HzHHXfgk5/8JADAYrGgt7c3oxdF0pR0hkqF2Ghw\n/swKBEMhfPXF/ejp90TWeuVaOmaKvd+DAac3UipTPshJFwvR69TwyHSA8vmTL1hLJXDFj4gzlQNA\nRBSWVnDW6XSRP//85z+PBOpoXq8Xjz32GDo6OrB27Vp89rOfTf8qKSKVtU8lxEaDv373DN6JOn84\nKC+bUw29To22Mz2wD7gh6DIftD2+IJ5+6SD6Br2yQa7P4YFH4vuqVEBTfVVM3+R4ZpMeKhVEH3LU\nKmDl/EmKE7gylQNARBSWNDhv374d27dvj3ntkUcewYoVK/Af//Ef+PDDD/HjH/844XNPPPEE7rjj\nDqhUKmzZsgVNTU2YO3eu5Pcxm4uh1WZvKtBqNWXt3KPt4bsaUVwkYP/xy+jqdaGyvAhL5kzEfetm\nQ6MRH6kpuf8aAG6vH21nukWPn+7oww+fWA0AuNLtRCAYwH/vO4cDx6+g15G5vcvhPdPhIFdcJOCB\n9bF/d0xlRbCai9BpdyV83lpehEfubkTFzpN4++A5uDyJQfxj8ycDAN7cezbh2G1Lp+HzG+Ypula5\nn1fbmW58bkMRDMLY5l2Op7/76Sjk+y/kewfy+/6T/q+xceNGbNy4MeH17du3Y/fu3fj3f//3mJF0\n2D333BP585IlS9De3i4bnO12p9JrTpnVaoLNNpC18482jy+AZTdW4dbGSTHTvz094nuPU7n/TrsT\nNpGABwBdvS60n+3CntYOtLbb0i4mkmp7yPeOXcInFk9JmBWYO92Cd450JLx/7nQLnA4P1i+fhs1r\n6/H/vt6KE+ftsA94Iklc65ZOBQA4Xd6EBK9PfWxaxn5eZ/7aPaZJYePt736qCvn+C/negfy4f7mH\nh7Qe6S9cuIDXX38dv/zlL6HXJ65xnj17Fj/84Q/xwgsvIBAIoKWlBbfddls634qiiK1tNsyoQHPT\nFFhKDRlJQkq2pr3r8AXZEpZKTLYaEwqbyOnpF090k4rv/mAQnXYnyox6WIsE3P/JGyWTtUa65SnT\nOQBERECawXn79u3o7e3Fgw8+GHntpZdewiuvvIJFixahsbER1dXVuPPOO6FWq7F69Wo0NDRk7KIL\nldja5p7WS9jTegkVGUpC0us0aKitxJ4WkRHpDEtMyc50LJ9TjXtvq8P2PWfw3gdXIuvVekENny8o\nOqLWC5qEIOfxBXBUYkvZ3qOX8afWy7CU6rF83mSsWzpVdlvTSLY8ZToHgIgIAFShkFib+dGXzemH\nfJjeAOS34nh8AXz1xf1Jp5Kbm2oSkpCU3n94ZN5yshM9A97I9HNFqR6zpprhDwZx4KPOpOdRqYZK\nWsazmPR4/sElkXvz+AKw9boQCASx52iH5L5lg6DB9x75WORzgWAQL/3+I+xXcC2A+M8kk4ZnNBL3\nP491tna+/N3PlkK+/0K+dyA/7j/j09qUWUq24iTrIxyWaiGSaPEj8/Ao1uHy4r3jV2Q2LsWqNhfj\nck9iDsGCemvMdel1GtRYjdi6q122oIj32kNLlbkYgWAQX3/lcErT4iP5mSjBimBElGnchJkDwkFR\nrp9weG0zmehey6mQqzzm8Q1F6WRTLAZBA4OgwZUeZ+TP4f7FqxZMxqrGyfD4YjOnlbSQjF673fp2\ne0qBGUj/Z5Kq8PQ4AzMRjRSD8xhLVo4zHMzCa5vJpJuEpHRkLkavU+OmG6vg9gbg9gYQAiJ/Xjqn\nGg0zLGg73YWvvngAX31xP7buakcgGFT8fcNrtx5fAK2nulK+PiZmEVG+4bT2GEulHGd0Na/ufrfo\nZ9JNQpLLOk5m2Zxqyb2+Le22mCIl8QU65L5vdJ9kYOhn1euQ7xstholZRJRvOHIeY3LT1VK1nZ97\n4CY8/8BNWLVgMipKDVCrhqaOm5tq0m5LqHRkDgwFTVXU92xumiL5gCFVPSw8KyD3fVfOn4R7P14f\nWXcv0mtRbhQUXWNYkXhq+nIAABFXSURBVF6D9Sump/QZIqKxxpHzGEtnK45ep8HEihLc+/F6eFZl\nrtlCfJ1tQacRDa4r50/C2sVTI9/T4wukPOruiZoVSNbtKTphLtWRs8cbgMPplW2iQUSUa/g/Vg5I\ntRVhtEy2JYzOOu7pd+Otw+dx4MOrkc5PBkGDJbOr0Nw0JeFhYNZUs2wt63gqADsPnsfmNXVJs53j\ns8hTUVlexPVmIso7DM45YKRbcTLdqlCv02BPawfebY3d3uT2BrD/w068e63Ax/yZlQgBOHaqC939\nHhgENQAVvL6A5Kg7LBgC9rRegkajjuxBFnvQkEuY0+vUKDFo0evwSn6/JXMmcr2ZiPIOg3MOSXUU\nnK1WhXIBMRwAu/s9CXWtwyPsJbMnoP28XVG3qmR7kOUS5nz+IL5413wIWjWMxQLe2Hs2YfbhvnWz\nJWuOExHlKgbnPJatVoUj2VYFACfP9cKucG04PiM9nnztaj2s5UWRwC42+yDVpYuIKJfxf648pXR/\ntNznO+1O0fcpLXgixZ5CwY9ke5D1Og2KDYldzwCg2KBLGHGzEAgRjQccOecpudFtd78bPf1uTKwo\nSTgWPxVebtRjfl0lNjfPjEyFy2WQK5FKS8hke5A9vgAGXeKj8EGXL7Idi4hoPOHIOU8lG93uOnxB\n9PX4UqF2hwd7Wjrw9VcOR6p2AUMZ5M1NNZF91AZBeQBUEpjVKmDVgslJM9L7HB7YB8SDc6/DMypl\nOYmIRhtHzjkqWQa2XGtHAGg705MwqpSbCr/Q6cDWt9tx79pZAGIzyMOdo/7Udhltp7sjCVfzZ1Zc\ny9Yefq2htgLHTtnQIxFQw0IhYO2iKUkT19gvmYgKEYNzjkklA7t5YY1kcBZLtEqW6PXe8SvYcEtt\npGBHIBjEr989E3MtDTMq0Nw0BZZSQyTwb7wl9kFCo1YlnRK3lIoH1vBDSZFeC5fHjzKjnv2Siajg\nMDjnmFQysC2lBlSkMKosM+pRbtRLJmx5fUG89nY77v/kjZLXEr83GUjcApZODXC5XtLzZ1Zi9cLJ\nMSN0pUVaiIjyEYNzDpGbdj58ohPrlk2DqXi4tnSqpT/1Og3m10lPhQPAifP2SAa3XDa43N7k+Epj\nuw5fQNuZHtnAKtVLOryfurmpBs89cBP7JRNRQWBwziFy0869Di+eefkQFs6KneJOtfTn5uaZOPFX\nOy73OEWP2weGk6yUdsuSEqkBvnaW7Bq6kp7O4QeCTJUqJSLKZQzOOSRZ20a7I3GKO9XSnxq1Gl/5\nu4V47P++B48vmHA8PB0eCIagF9SRql9i70mlbKhc9TMlRU+UPhAQEY0HDM45ROn+4tb2LqxbNi2S\nMKXXaVIq/Vms12HFvEmS0+EAsPXtdtHADADzZlYkJIqNpGyokl7SzMwmokLC4JxjwtPRh090SrZH\n7O534+mXD6LP4U07MIpNh8+bWYFQKISvvrhfMlAaBA2CwRB2tw6vW4+0bKiShxJmZhNRIWFwzjHh\naep1y6bhmZcPSWZWhwN3KoExfho6fjr81++eSTpq93gDOHaqW/RYskQxOeGHhZaTNvQMeGKytcMP\nH0REhYLBOUeZigUsnKW8hGZ8YIwOxIFAEFt3tYtOQ4enw5UkZQFAmVFAr8QDw0jWhePXzqP3OXPE\nTESFhsE5h8VPPZeVSO9RDgfGijJDQhGTMqMeZy/1R94rNtpW2omqcWYl2s50Z61iV/TaefS2MSKi\nQsLgnMPERpNff+WQbGAUKxwitX4cPdpOlpRlMemxoP7a2rbmNCt2ERFlEYNzHogeTcoVHQGkC4eI\niZ6GlkvKWj6nGlvW1kcCb6p7q4mIKDUMznkkEAzCHwhCr1XD4x/a5mQQNFg2txqbVteiu8+taGo6\nLH4aWi7oRmeCp7q3moiIUsPgnMOik7q0GhW+/sphXOh0xLzH7Q3A6fLjctegov3C0eKnoVMNuqns\nrSYiIuXSCs6/+c1v8P3vfx9Tp04FACxbtgyf//znY97z5ptv4uc//znUajXuuusubNy4ceRXWyDE\nOlMZ9Fp02AZF37//o6vY/9FVGAQ1KsuKACQG5ylVRjjdfkXT0Ay6RERjK+2R8+23344nn3xS9JjT\n6cQPf/hD7NixAzqdDnfeeSfWrFmD8vLytC+0kIgldYkF3HhubxAXbYMJgXj5vElYt3Qq/IEQp6GJ\niPJAVqa1jx07hrlz58JkMgEAFixYgJaWFqxevTob3y4vKK1D7fEF0HKyc0Tfa9Dlw9OfXRTZJ1wz\nqRw22wA0anBETESUB9IOzgcPHsT9998Pv9+PJ598EjfeeGPkWFdXFywWS+Rri8UCm00+i9hsLoZW\nm73RnNVqysp53V4/7P0emEv1MAiJP85AIIiXf/ch9h+/DFuvC9byIiyZMxH3rZsNjUad8N4f/Ooo\negbEy3YqZR/woKjEgOnXlURey9b954NCvneA91/I91/I9w7k9/0nDc7bt2/H9u3bY177m7/5Gzzy\nyCO45ZZb0NraiieffBK/+93vJM8RCoWSXojdLt7CMBOsVhNstoGMnlNsXVisxvXWXe0xU9Sddhfe\n3HsWTpc3odzm1l3t2K2wIpgcs0mPgNcXueds3H++KOR7B3j/hXz/hXzvQH7cv9zDQ9LgvHHjRtlk\nrsbGRvT09CAQCECjGRr5VlVVoaurK/Kezs5OzJ8/P5Vrznli68LxVbfkSmKKldtMtkc5vJbc0++G\nTqeGV6TlIwAsqLdyTZmIKI+l3t8PwIsvvojf//73AID29nZYLJZIYAaAefPm4YMPPkB/fz8GBwfR\n0tKCpqamzFxxDkgWdD2+AAD5kpjhAiBhycpnLptTja99pgnPPXATvvW5Jfjuwx/DrQsnwyAM/9wN\nggarF05mMRAiojyX1przunXr8KUvfQmvv/46/H4/nn/+eQDAT37yEyxatAiNjY147LHHcP/990Ol\nUuGhhx6KJIeNB0qCbpW5WHbfcXwBELn3VpTqce/aemjU6pikrv+1ph533lILW68LCIVgvVbpi4iI\n8ltawbm6uhqvvvpqwusPPvhg5M+33XYbbrvttvSvLIcpDbpyJTHjC4DIv1d6mlqv06DGakz3VoiI\nKAexQlgaUgm6qdShZs1qIiICGJzTpjSQplISkzWriYgIYHBOWzbrULN8JhFRYUsrW5uGhQNpvoxw\nPb4AOu3OSEY5ERHlHo6cC4TSoilERDT2GJwLhJKiKURElBvG7ZCJ07fD3F6/oqIpRESUG8bdyFls\n+nb5vMlYt3Rqzk/fKu1clSp7v7KiKURElBvGXXAWm76VajSRK7K9HmwuVV6pjIiIxl5uDyVTpLTm\nda4JP1B093sQwvB68LbdpzNyfoOgRWOdVfRYfNEUIiIae+MqOKfSaCJXjNYDxabVtWhuqkFFqQFq\nFVBRakBzUw2rjxER5aBxNa2dSqOJXKG0icZIsfoYEVH+GFcj53DNazG5On0bfqAQk40HinwrmkJE\nVIjGVXAGxKdv71gxPWenb/PxgYKIiLJrXE1rA+LTtzWTymGzDYz1pUliNyoiIoo27oJzWD41j8jm\nerDHF8DlrkEEfAGOwomI8sS4Dc75KJMPFDF7pwc8sJhYS5uIKF8wOI9TrKVNRJS/OIQah/K1GAsR\nEQ1hcB6H8rEYCxERDWNwHodGe+80ERFlFoPzOMS900RE+Y0JYRmWrbaPqeLeaSKi/MXgnCHZbvuY\nqui90xpBh4DXxxEzEVGe4LR2hmS77WO69DoNJlaWMDATEeURBucM4NYlIiLKJAbnDODWJSIiyqS0\n1px/9KMfYd++fQCAYDCIrq4u7Ny5M3L84sWLWLduHebMmQMAMJvN+MEPfpCBy81N+dhHmoiIclda\nwfnzn/88Pv/5zwMA/vM//xPd3d0J77n++uvx6quvjuzq8kR461J0ucwwbl0iIqJUjShb2+/347XX\nXsMvfvGLTF1P3uLWJSIiypQRBee33noLH/vYx2AwGBKOdXV14Qtf+AI6OzuxefNm3HHHHSP5Vjkv\nm20fiYiosKhCoVBI7g3bt2/H9u3bY1575JFHsGLFCtx///149tlnUVNTE3Pc4XBg586duOOOOzAw\nMICNGzfitddeQ1VVleT38fsD0GoZzIiIiJIGZylOpxMbN27Ef/3XfyV976OPPop77rkHS5YskXyP\nzTaQzmUoYrWasnr+XFfI91/I9w7w/gv5/gv53oH8uH+r1SR5LO2tVCdOnMD06dNFj+3fvx/f+ta3\nAAwF8RMnTuD6669P91sREREVlLSDs81mg8ViiXnt+eefx4ULF9DU1IS+vj5s2rQJn/70p/Hggw9i\nwoQJI75YIiKiQpD2tHamcVo7ewr5/gv53gHefyHffyHfO5Af95+VaW0iIiLKDgZnIiKiHMPgTERE\nlGMYnImIiHJMziSEERER0RCOnImIiHIMgzMREVGOYXAmIiLKMQzOREREOYbBmYiIKMcwOBMREeWY\nggjO3d3d+Pu//3vce++9uPvuu3Hs2LGxvqRR4/f78eSTT+Kee+7BXXfdhcOHD4/1JY26gwcPYunS\npdizZ89YX8qo+uY3v4lNmzbh7rvvRltb21hfzqhrb29Hc3MzfvnLX471pYy6b3/729i0aRM2bNiA\nt956a6wvZ1S5XC48+uij2LJlCzZu3Ji3/+61Y30Bo+HNN9/E3/7t32LdunU4ePAgvv/97+Pll18e\n68saFb/97W9RVFSE1157DadOncKXv/xl7NixY6wva9ScP38eP/vZz7BgwYKxvpRRdfDgQZw7dw7b\ntm3DmTNn8NRTT2Hbtm1jfVmjxul04hvf+AaWLl061pcy6vbv349Tp05h27ZtsNvt+NSnPoWPf/zj\nY31Zo2bPnj3/f3v3D5JaFIAB/BNvRtHfK9ewLVqKIlqaoqJoimgTWguChhqL4g7NRrQooZiDQ2Bo\nBEFDEVE0BOGoREtLiFEXScqSQHhDcHnCe5EP3j3q+X7TuWf6DlzOxz2IB/39/VhYWEA6ncb8/DzG\nx8dFxyqbFOU8NzdnjjOZjFTXV87MzGB6ehoAoKoqXl5eBCeylqZp8Pv90HVddBRLXV9fY3JyEgDQ\n3d2NXC6Ht7c3NDU1CU5mDYfDgVAohFAoJDqK5YaGhjAwMAAAaGlpwcfHB4rFIux2u+Bk1piamjLH\n1bzfS1HOwNf904uLi8jn84hEIqLjWKaurs4cRyIRs6hl0dDQIDqCEIZhoK+vz3xWVRXPz8/SlLOi\nKFAUaba3Ena7HY2NjQCAeDyO0dFRaYr5d7Ozs3h8fEQgEBAd5Z/U3Nsbi8UQi8VK5paXlzEyMoKD\ngwNcXl5ifX29Jo+1v1v73t4eUqlU1b6oP/Hd+mXHf+mVz9nZGeLxeE3udT8RjUZxe3uLlZUVHB0d\nwWaziY5UlporZ4/HA4/HUzJ3c3ODXC6H1tZWjI2NYXV1VVC6/+tPawe+Suv8/Bw7OzslX9K15m/r\nl5HL5YJhGObz09MTNE0TmIisdHV1hUAggN3dXTQ3N4uOY6lkMgmn0wm3243e3l4Ui0Vks1k4nU7R\n0coixa+1T09PcXh4CAC4u7uD2+0WnMg6Dw8PiEaj8Pv9qK+vFx2HLDI8PIyTkxMAQCqVgsvlkuZI\nW3avr6/Y3NxEMBhEW1ub6DiWSyQS5mmBYRh4f39He3u74FTlk+JWqmw2i7W1NeTzeXx+fkLXdQwO\nDoqOZYnt7W0cHx+js7PTnAuHw3A4HAJTWefi4gLhcBj39/dQVRWapklzzLe1tYVEIgGbzYaNjQ30\n9PSIjmSZZDIJr9eLdDoNRVHQ0dEBn88nRVnt7+/D5/Ohq6vLnPN6vSV7QC0rFArQdR2ZTAaFQgFL\nS0uYmJgQHatsUpQzERFRNZHiWJuIiKiasJyJiIgqDMuZiIiowrCciYiIKgzLmYiIqMKwnImIiCoM\ny5mIiKjCsJyJiIgqzC8iivHPF8qqogAAAABJRU5ErkJggg==\n",
- "text/plain": [
- "\u003cmatplotlib.figure.Figure at 0x7f7a18dfb8d0\u003e"
- ]
- },
- "metadata": {
- "tags": []
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the Data (Optional)\n",
+ "assert f(pi/2).numpy() == 1.0\n",
"\n",
- "import matplotlib.pyplot as plt\n",
"\n",
- "plt.scatter(inputs, labels)\n",
- "plt.show()"
+ "# grad_f will return a list of derivatives of f\n",
+ "# with respect to its arguments. Since f() has a single argument,\n",
+ "# grad_f will return a list with a single element.\n",
+ "grad_f = tfe.gradients_function(f)\n",
+ "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "JaFHyAG9nDET"
+ "id": "v9fPs8RyopCf"
},
"source": [
- "## Step 2: Define our TensorFlow variables\n",
+ "### Higher-order gradients\n",
"\n",
- "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias."
+ "The same API can be used to differentiate as many times as you like:\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
- "base_uri": "https://localhost:8080/",
- "height": 34
+ "height": 276
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 332,
+ "elapsed": 730,
"status": "ok",
- "timestamp": 1525154229931,
+ "timestamp": 1527005655565,
"user": {
"displayName": "",
"photoUrl": "",
@@ -190,54 +120,61 @@
},
"user_tz": 420
},
- "id": "z9r-ZeyrXu3A",
- "outputId": "e19a698e-5892-4fcd-80d3-1394605ee72c"
+ "id": "3D0ZvnGYo0rW",
+ "outputId": "e23f8cc6-6813-4944-f20f-825b8a03c2ff"
},
"outputs": [
{
"data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEDCAYAAAAhsS8XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXd0HNX5sJ/ZXrTq3ZLV3IvcDdgGGwOm2WCbHhJa6C2B\nUBISQioBfoQPkjhACA4QCIQSDITQbGMbsHHvVbZ6s7q0vc18f4xmJVltJa0q+5zDOXhn9s7dqzvv\nfe/briBJkkSYMGHChBkxqAa7A2HChAkTJrSEBXuYMGHCjDDCgj1MmDBhRhhhwR4mTJgwI4ywYA8T\nJkyYEUZYsIcJEybMCCNkgl0URVasWMHtt98eqibDhAkTJkwvCJlgf+2118jJyQlVc2HChAkTppeE\nRLBXVlayceNGrrjiilA0FyZMmDBh+kBIBPvjjz/OQw89hCAIoWguTJgwYcL0gT4L9g0bNhAfH8/E\niRMJVycIEyZMmMFH6GutmGeeeYYPP/wQtVqN2+3Gbrdz3nnn8dRTT3X6HUmSwtp9CKittvH8UxsQ\nxZY/4aXXTGfa7PRB7NXAU1dj5y9PrIfmYUgeFcnya2aQmBI5uB0bYE5WNPHS/9uE6JcHYukVucw8\nPWOQezXw7NhcyCfvH0Bqfi+uumkO4ycnD3KvBpY+C/bWbNu2jdWrV/PCCy90e291tTVUj+03EhIs\nQ7qfWzfls2tzMTNPH01UrJEv/3eU5LRIVnx/5mB3rUP6azw3fnaMQ7vLOX1RNrVVNvIOVZGeFcPS\nq6YNmT6GmlP7KYoi/3ltF9WVNhacO4btXxfi9fi5+Mpc0jJjhkw/+5t9O0r5Zu1xDEYtpy/KZuOn\nR4mOM3HlTbNRqTo3UAynv3swhOPYhymSJJF3sAqtTs35l05mQm4K6VkxVJY2UVdtH+zuDRgOu4ej\n+yqIjDYwbW4a514yiYTkCMqKGnC7vIPdvQFjz9YSqittjJuSxNTZaVywcgoAX3xwCL9PHOTeDRyH\ndpej0ai47PqZTJyWwoTcFOprHBzdf3KwuzaghFSwz507NyhtPUzfOVnehLXRRdbYeLQ6DQATp6UC\ncGhv+WB2bUA5sLMMv19i2pz0gEaWNS4BUZQoOlE3yL0bGDxuHzu+LsRk1jH/nDEApI6OZtL0VFxO\nLyfLmwa5hwNDU4OT+loHozJiiIw2AjB7QSYajYrtXxXg9foHuYcDR1hjH6bkHawCYOzkxMBnmWPj\nMJq1HDtwEt93YBJ7PT4O7CrDYNQwPrfFhpo1Lh6AgmPVg9W1AaWyrBG/X2JCbjIGozbweXqWbIIp\nLawfrK4NKMX58kI+Oic28FmERc/UOWnYbR7yDn53tPawYB+GiKLI8SNVGEzaNvZTtVrFhKkpuF0+\n8o+OfKGWd7gKt8vHlJmj0GrVgc9j4kxExRopzq/7Tixw5cWNAKSkR7f5PHV0NIIApUXfEcHevEMb\nnR3b5vPxU5IAqChpHPA+DRZhwT4MKS2sx+XwMmZCYjuH0MRpKQAc3lsxGF0bUJQXNWdiYpvPBUEg\ne1w8Pq9IyXdAW60oaUAQ5Gig1uj0GhJTIqkqb8Lj9g1S7wYGn89PWXE90XGmgBlGITrWhN6gobIs\nLNjDDGE6MsMoRMUYSUiOoLKsacQ7zaoqrGh1amLiTO2uZY1LAKDgWM1Ad2tA8Xr9VFVYSUi2oNNr\n2l1Py4xBkqC8pGEQejdwVJQ04vOKZJyirYO80CeNiqSpwYXD5h6E3g08YcE+zJAkiZKCOswWHUmp\nHcdpJyRbEEWJupqRGx3jdvloqHWQmGLpMCciMcWCOUJH0fEaRHHkLnBV5U2IokRKelSH10dlyOaZ\nkW5nD5hhcuI6vJ48Sh6fyrLvhiM5LNiHGQ67B6fDS2JyZKdJXgnJcqxrdeXQj8vtLcpv6ywJSRAE\nMsfF43L6RvTLXF4sa+Kn2tcVkkdFodGoKCvqu8b+zjtv8f3vX8Fvf/ton9sKNUX5tWi0KlLSOl7g\nFDPVSJ4LrWm/dwszpKk5aQMgLimi03u+C4JdCeFLTOk8YSMlLYqDu8qpOWkjtRPBN9wpb/YzpHai\nsas1KlLSoygpqMdhc2OK0Pf6WWvWvMsf//hnkpNTet1Gf9BY76Sxzknm2DjUmo51VXlnBye/I3b2\nsMY+zKitkgV7fGLngj02wYxKLVBdaRuobg04VRWyYO/MHAUQlyCPkTJmIw2/X+RkeRNxCWb0Bm2n\n943KaA577IPW/vTTf6C8vIyHH76ft99+s9ft9AeKUzQto/MMW61OQ1xiBFWV1hHve4Kwxj7sUDT2\n+C40drVaRVyCmdpqG36/iFo9stZvSZKoKrditugwWzrXQKNijahUwojLxH17/XF25VXj8fhx+Hzo\nGh1s/+vmTu8X/SJ2RA59egTDxhMd3jNnQiJXLh7TaRsPPPAztm79lj//+UUiI4dWDZ76Zl9SXBfK\nDshmqZqTNqpPWgM295HKyHrjvwPUVNnQGzRERHa9pU5ItiD6pREn1ADsVjcOu6fbIl9qtYqYeBN1\nNfYRWXlU0Ty7W7hVzddFf181VYlApbUhhDLHYxPMXd6XnCbPl5PfATt7WGMfRng9PhrrnM2JJ11X\nx5Tt7BVUn7QGbO4jhZPliuO0+98VlxBBbZWdpgYnUTHtwyKHI1cuHsNdV83g1b9upuhELdf/cG63\ntvN/v7ydpgYnN99xxoirrFpX48Bk1rXJuu2IlsiYRqYxsiughjX2YURts2bSlRlGocWBOvLsy4p9\nPZiyvLGJshZXWzXydi71tXYMJm1QDtHYeBM+r4itaWTFcXs9PqyNLmLiu1+0IyL1mCN0VJY2jcgd\nXGvCgn0YEbCvd2NLBIiNN6NSCdSMwMiYqoqeaeww8hyoPq9fFmixwe1CouPkBa6+ti8L3NDT9Otq\nHED3ZhiQQ2ATUyNx2D3YbZ7+7tqgEhbsw4hgHKcKao2K2AQztVWyA3WkIIoS1ZVWYuJNHWZankpc\n8wtfO8J8DbLfAKI7yLrtiNhmjba+WRD2hnfe+YDIyKHldAzY1+O7F+xAIEu5sa734zAcCAv2YURt\nlQ2VWgj6ZU5ItuD3S4GogZFAU4MTr8dPQlJwfgNThA6DUTPinMi11fIi31E5hY5Q5kx97cgSaMrc\nDkZjB7luDEBDnbPf+jQUCAv2YYIoitRW24mNNwcdvjgS7eyN9fILGR1r7OZOGUEQiE2IoLFeXhBG\nCjXNpqXoYE0xMSYEoa+mmKGHUjYjJi44wR7VPG/CGnuYIUFDnRO/TwzKDKOg3DuS7MuNzZpWVJAC\nDVrMMSOpdk5AsAepsas1KiJjjNTXOEaU47Cuxo7ZokdvCC7AL6yxhxlS9MRxqqBotU0NI2cSN9bL\nmlZUTHAaO7Qkrijmi5GAYpazRBmC/k5snBm3y4fTMTKODHS7vNitnqDNMAAGoxaDUUNDWGMPMxRQ\nbMTdZde1Rm/QojdoaGxw9Ve3BhzFFNMTwa68+HUjJORRkiRqquxExciZtcESHXCgjoxxCETEBBHq\n2JqoWBNNDc4RFVRwKn0W7B6PhyuuuILly5ezbNky/vKXv4SiX2FOQdG6eyLQlPubGpyI4sjYfjfW\nOzGatEFFxCgoERMjJTLGYfPgcfuCdpwqxI4wB2rAcRpkRIxCdIwRSQJr48hReE6lz4Jdp9Px2muv\nsWbNGtasWcOmTZvYt29fKPoWphVNDU7UGhWmCF2PvhcZbUT0S9itwz8xxe8XsTa6Ag6wYNHq1ETF\nGKkbIaYYRTAHa19XiGkWgL0NeWxdtvebb77ijTdeDfq7lZUVfPHFp0Hd+/jjv2bjxvXd3te6lMCa\nNe/x2Wf/C6r9qICdXR6HTz75L9XVLUdJPvnk7ykqKgyqraFKSEoKGI3yi+bxePD5RvYRXINFU4OL\nyChDj9PBFQ2/sd7ZI3vsUMTa6EKS6FVpgMgYIyX5TjxuX4+0/aGIIpCCTU5SUByHvY2MObVs7/z5\nZ7a7x+/3o1ar231eXl7GF198xnnnXdCrZ3eE4gyPjDawfPllQX9PGQfFEf+//33EzJlTSUrKAODh\nh38esj4OFiGZ4aIosnLlSoqLi7n22mvJzc0NRbNhmnG7vLhdvnZnWgZDZLQszGVTTudlTYcDgYiY\nHpqjACKjlHFw9SiyaCjS0EuNXatTY4nU98oU07ps78UXX4LFYuHIkUPcd99DPP74r7FYIsnLO8r4\n8ROZP/9MnnvuaQRBQKvV8OyzL/Dii6soKirkppuu5YILlnLllde0af+ZZ55k9+6dpKSktonaOXr0\nCH/+8zO4XC6ioqL5+c8fIzY2jnvuuQ3RFUt1XSGxH5Zht9sxmUycccYCfve7x3jpJXk3UVlZwcMP\n38+rr77JK6/8nW+++QqHw4lWSmTS9HvYsGEdR44c5sEHH0Sj0fL886t54IF7ufvu+zh8+ADl5eXc\neee9gKzZHz16hB//+AE+//wT3nnnLfx+H5MmTeEnP/npkKrBExLBrlKpWLNmDTabjTvvvJPjx48z\nZkznJUDD9IymZufnqYf0BoMiBEdCZExDLyJiFJQFztroHPaC/Vv315ROK+RP+ZsRCnomTJxjPfh8\nInnfbGgjiGYkTmXlmKWdfu/Usr2ffPLfNt8vLS3mT396AYCHH76Pn/zkp0yZkktEhIamJg+33343\nb731Ok8++f/atb1x45eUlpbwz3++TU1NDd///hUsXXopPp+PZ599iieeeIaoqGjWrfuCF19cxc9+\n9kskScLpsHPdlT9l6VXTWL36bwBkZGTi9/uoqCgnJSWVdes+55xzzgPgssuu4oYbbsbn9XPj9+9k\n566t3P/Idbz33tv88pe/ICGhbWGwRYvO5fbbbwwI9nXrPuf6639IUVEh69Z9zgsvrEatVvPHPz7J\n559/wvnnX9Sjv0V/EtI9aUREBHPnzuWrr77qVrAnJAyPioNDoZ/VzdUMU9OiO+1PZ58b9HLFO5fD\nNyR+S1/64HHKCUaZ2fE9bidttLxb8fukbr87FMapK9xuH6oIAU0v6uyrNSp8PhEkUKtbBLPJqOv2\nd6tUEBdnJjragsViwNj8HYNBy8KFSwPfP/30uTz//HMsW7aMJUuWkJSURHS0CZ1O0+Ezjh07wIoV\nl5KQYCEhwcK8eWcQGWnEZquhoCCfBx+8F0mSEEWRxMREEhIsCAhkpE4nMTmShAQLZrMes9lAQoKF\npUsvZuvWTdxyyy1s2rSeZ599loQEC7t2bebll1/G6XRSVV9FWXkaCQkWtFo1ktQyL7RaNTExJsaO\nTSczM4OKigJGjx5NeXkpixcv4I033uD48WPccceNSJKE2+0mLS15SM2bPgv2uro6tFotFosFl8vF\nli1buPXWW7v9XnX10C9OlZBgGRL9LC2WDyJWaYQO+9NVPyVJQqNVUV1pHfTf0tfxPFkhn5QjIva4\nHalZhlWUNnb53aHyN+8Mr9dPbN5YZo45gwvPn9Lj7x/aU87GT49x9sUTmDA1uc217n63KErU1trw\netVYrS6cTg/V1VZcLi8+X8vcXLHiGqZNm8uWLV9z5ZVX8swzq2hocODx+Dp8htPpwWZzB6653V6a\nmpzU1dnIysrm+edXt+uny+VFY9Gh0amorrZit7uRJDXV1VZOO+0sHn30p8yaNQ+/X8JojKGsrJZf\n/erXrF79OvHxCTx8/2+wNjgpL6vH6/W3+f1er5/6egfV1VYWLDibd99dQ0ZGJvPnL6S62orV6mTJ\nkou47ba7ejR+oSDYxaPPUTHV1dVcd911XHrppVxxxRUsWLCAhQsX9rXZMK1QzCi9McUIgkBktJHG\nBuewzzhsqHNiMut65fxsbYoZziip8PGJPQvxU1Ac6LZ+DPUrKyslOzuHa6+9nilTplBcXIjJZMZu\n79hpO23aTNau/RxRFKmpqWHXrp0AjB6dSX19AwcO7AfA5/NRUJAPtBwy0tE7MWpUGmq1ilde+TuL\nF8tmGI/HgyBAZGQUDoeD44W7geY5ZTJhs3UcMbVw4WK++mpDG5POrFlz2bBhHfX1ssLV1NREZWVl\nr8aqv+izxj5+/Hjef//9UPQlTCcoSTmW6N5FtcihfnacDi8mc8/CJYcKfr+IrcnV6yPN9AY59r1p\nmMcuK6nwSjninqIIdmtTb8YhOHv+O++8ya5dO1Cr1YwfP47TT58PgFqt4cYbv8eFFy5r4zxduPBs\ndu3azvXXX016egYzZswCQKPR8LvfPcmzz/4fNpsNUfRz5ZXXkJWVjd8vtfk9p7J48RKef/5P3HLL\nnYBsJl62bAXXXXcVKSmpZGeNw14vv1sXXbSMxx57DK1Wx/PPr27jO7BYLGRmZlNcXMiECZMAyMzM\n4pZb7uT+++9CFCW0Wi333/8QycnJHfZlMBCkQVLjhvJ2V2GobMtff/5b/H6R6++e1+H17vq5ef0J\n9m4rYcX3Z5CcNnhlV/synvW1dt56aTsTpiZz9sUTetXGO//YQUOtg5t/cmanEQxD5W/eGbu/Lebb\nDflcddOcwCEiPcHn8/PS018xKiOaS66Z3g89bEt/jeen/zlAwbEarr9nXq+UleL8Wj5+ez9zFmQy\ne0HmkP+7KwyYKSZM/6JoqpG91NahVSz7MI6MCZQS6GFyUmssUQZ8PhGnffgesqBkS0b38pg/jUaN\nyawb9lmX1kYXGo0Ko6nr4/A6I1AMrH5kZOGeSliwD3FsTW4kCSKjei/QomLkRUERjsORvsSwKyj2\n2OFsjrE1m1D6Mg4RUfrmeTV8fS7WRheWXiTsKUREGlCpBJrqh+9c6IqwYB/iBBynoRBoI0Fj78OB\n1C3JWsP3ZbY2udHp1d0e3NwVkVEGRFEatsfDuV0+3C5fnzKpVSoBc4QOm3X4zoWuCAv2IU5LclLv\nJ7GinQxrjT0g2Hs/DgHH4TBd4CRJwtroIiKyb6UhlO/3Z2RMf6LsWvpaIiMi0oDd6hmRVR7Dgn2I\n05dQRwWVSsASbRjW205rkwuDSYtW1/tAroDGPkwFmsftw+vxY4nU96kdRSAO13FQ+t3bKDGFiCh5\nHEdCgbxTCQv2IU6LYO/bJI6KNuJyyjVnhhuSJGFvchNhCZFAG6amGGujLIAi+qipBmLZexXyOPhY\nlV1sCDR2kP1YI42wYB/iNDXI3v++xp8PZzu72+XD5xOJ6KOmqtGoMUcM34gQJfbc0kdTjPL94TYO\nu3fv5KGH7gv0uzNTzD333MbRo0e6bU/Z+diaXPzpT39i587tverX22+/idvdsjg89NCPsdsHt0R0\nWLAPYSRJoqnBiSW6995/BWXbaRuG207lRY6w9L3ssCXagK3JNSztqoqG3dcFztI8F4abYAcQBLoV\n7MGiaOyNDU7uvfdeZs2a06t23nnnTdzulrF86qlnMZsHt9Dc8C5MPcJxu3x43H5S0ntvX1dQzBjD\n0Z6oLEZ9FWggh41WljZht7r75LcYDBRTTF8FmlanwWDU9Eiwu1wufvnLn1JdXYUoilx//c0sXnxu\np2V1y8pK+b//exybrQlJEvjtb58gNXUUq1Y9x9atmxEEFddddxPnnHMeu3fvZPXqvxEVFU1BwQkm\nTJjIo4/+FoBvv93Mn//8DNHRMYwdOx6ApkYnGq0qEBnkdrt5/PFfU1RUSEZGBh5PS7TP9u3f8vLL\nf8Pr9TJqVBqPPPIYBoOBK664hLMXXcAXm7/Er1vGV9veZNas09HrDfzvfx/xm9/8AZB3Cf/+9xs8\n8cQzPP30Exw9egi3282iRedw00238u67b1FTU80999xOdHQ0zz33PFdccQkvv/xP3njjNZKTU1ix\n4nIAVq/+G2azmauuupZ//euffPnlF3i9Ps46axE33dR9fa2eEBbsQxjlxeurLRFaBPtwtCfam0In\n2C2tQh6Hm2BXNHb/hv+yY9WePu065tg8iH6J/IffBsAyew4JV1zd6f1bt24mPj6Bp556FgCHw95l\nWd1f//oXXHfdjaxYsZTy8jpEUWTjxvWcOJHHa6/9m/r6Om6++TpmzJgJQF7eMV5//R3i4uK4444f\nsn//XsaPn8hTT/2eP//5RUaNSuOXv/wZ0D6Gfc2adzEajbzyyr84ceI4N910LQCNjQ28+upqnnvu\nr+j1Bt5441Xeeut1brjhZvk3R5pZMu8uRqfHcrSkVB6XOafx9NN/wO12odcbWLfuCxYvXgLAbbfd\nhcViQRRFfvSjO8jPP87ll1/Nv//9ZqCcsYzcr3PPXcJzz/0xINjXr1/LM8/8me3bv6W0tJiXXnoN\nSZJ4+OH72bt3D9OmhS4TOCzYhzC2EAo087DW2BUTRN8XuMCBG43D7+ARa5MLlUpAq1XTVxe4oBKQ\n/CKSJJs3uiM7ewyrVj3HCy/8hTPOWMC0adPJzz9Bfv4J7rvvruayuhLx8Qk4HA5qaqpZsEAuBqjV\nypr1vn17OPfc8wGIiYllxoxZHD58CJPJxKRJk4mPjwdgzJhxVFRUYDAYSU0dxahRaQAsWXIhH3zw\nH3kXm9ayKO/Zs5srmhelnJwxjBkzDoCDBw9QWJjPHXf8EEmS8Pl8TJkyLfC9JUvO57//ymuj7KjV\nak477Qy+/vorFi1azJYtX3PXXT8CYN26z/jwwzX4/X7q6mopKCggO3sMIDX/pyD//9ix42loaKC2\ntob6+noiIyNJTEzinXfeYvv2bdx007VyXXmni9LS4rBg/66gCGFzH6NBWrcxHCMhrMoCF4JxaHEi\nD79xsDW6MVv0JF55NQl33dKn2ibfrDvOvu2lrLxuJkmp3Z/MlZ4+mpdffp0tW77hxRf/wty5p3PW\nWYvIzs5pV1bX4ei4iuOpma6t/60IfwC1WoXf3/HS5fPKu5RTzVGtfVBKu5IkMWfO6Tz22O86bMto\nNBIRaWj3TixefB7/+c/bREZamDhxMkajkYqKct566w1efvmfmM0RPP74r/F4uleSzj77HL78ci21\ntbWcc86SQL9+8IMbuOSSFd1+v7eEnadDmIBtOQQCTa2WD8Iejs5TW5MbQQCzpe+VKZXdj32YmaT8\nPhGH3ROyc2t7GvJYU1ODXq9nyZILuOaa73Ps2NFOy+qaTGYSE5P46qsNAHi9XtxuF9OmzWTdui8Q\nRZH6+nr27dvDpEmTO31mRkYmlZUVlJeXAbB27Wf4mmuntx6H6dNn8PnnnwCQn3+cEyfyAJg8eSr7\n9++lrEw2s7jdLkpKits8IyJSj8ftlw8faWbGjFkcO3aUDz9cEyjVa7fbMRqNmExm6upq+fbbzYH7\nuypJvHjxeaxb9zkbN67n7LPPAeC0007n448/xOl0No9tdaAEcKgIa+xDmFBq7CAvEDVVNiRJGlLn\nM3aHvcmFKUKHStV3PSSwcxlmC5xijuprcpKCEvIYbJJSfv5xVq16DpVKQKPR8sADP+uyrO4vfvFr\n/u//HueVV15CENT89rdPsHDh2Rw8uI8bbrgGQVBx5533EhMTS2FhQZtnKXNTp9Px4IOP8OCDPyI6\nOobc3OmcrKiT+99KsC9ffjmPP/5rbrjhe4wdO45Jk+QDSKKjo3nkkcf41a8ewePxIggCt9xyB+np\no1Hs4Ip5T1kwQD7qc968BXzyycf84he/BmDMmLGMHTueH/zgKlJTR5Gb22LSueSS5TzwwL3Exyfw\n3HPP07q8cVZWNg6Hg4SEJGJj4wCYM+d0iooKuf32GwEwmUw8+uhviYkJnWkwXLa3Cwa7lOcH/9pD\neXEDtz54FuoujkELtp99LXXaV3oznqIo8dLTm0hIsbDyBzND0o9X/vQNOr2G7912Wkj6OBCUFtbz\n0Vt7mTUvg7lnZfW5nzUnrbzzj51MmZnKmUvGhbCnbQn1eH79RR77d5Zx+Q2zSEju+1F0u7YUsXVj\nAVf/cC4xCb2vQzRQhMv2jgDsVjdGs7ZLod4ThmPIo8PuQRSlkJijFMwWPXbr8KpuGKr6KAqBujnD\nLJY9lKGvcjtKlNTwS9zrirBgH6JIkoTN2vc0+tZERA6/kMdQJeW0JsKix+cTh1V5hUCSVojGQT5R\nSh1wTA8X7FY3KrXQp+qWrVHGczifVdARYcE+RHG7fPh9Ysjs69CqNsYwKlVqDziQQ6OpwvAM/VQW\n41Bp7CDb2a2NrmG1c7Fb3Zgj9CHzEQV8DcO48mlH9FmwV1ZWct1113HRRRexbNkyXnvttVD06zuP\nLYQhfgrDWaCFUmMfjg5UpU5MSOdDpB6vx4/X4+/+5iGAKMqRQaGIjlIwRegQhJGnsfc5KkatVvOz\nn/2MiRMnYrfbWblyJfPnzycnJycU/fvOEuqIGBie2afWfjDFBBY42zAah0YXRpMWjVYdsjbNES0L\nvU4/9APkHHYvktTS71AghwHrh/VZBR3RZ409ISGBiRMnAmA2m8nJyaGqqqrPHfuuE8oYdgVThKzp\nDCfB3qKxh84EEXAiD5NxCPhbQjgGMPwWuP5QdkAOIW1qdCGKw8ck1R0htbGXlpZy5MgRcnNzQ9ls\nv2I/sA9nfv5gd6Md/TGJ1WpV83Fg7V9kT2UFjsOHQvasUKE4y3p7aHFHKFv51uMgSRKi14vo9SL5\nhpZT1enwIvqlkO5aAMzNC73d2lI0S3S7se3Zjd/WtuyszWbj/fffDfxbKaHbEU8++XuKigq7fX5X\nbbRGKcMbeCeC0NhffvnFoMvwRkQakEQJR/MC9/bbb+JyOrHu2IansmJIlOHtKSHbf9ntdu69914e\neeQRzGZzt/cHG4/Z3xT+8xU8dfWkLL2IjB9ci1rfdtIMVj+V1OnRmXHExoduPKNiTVSWNRIfFwGS\nSPlHH1P15QYchUUApF9zFaOvvrL3HQ9RPxUcNg9R0UYSE7tPew+WSItcVsDr8ZOQYEH0ejn8uz/Q\nsGcvxwFUKjK+/z3SLuu/lO+eUOlpBCA+IaLN+PV1bqamRcv/I8lt+Z1ODj3zJE2HDiOo1URNyyV1\n6UXEzJqJ293IRx/9h1tvlZNqoqNN6PWaDvvw9NNPtPm3co8oim2SzLpqozVarZqYGBP2etlhmjIq\nqsvviKLIT3/6QPcD0ExisoXjh6vQqNUkJFh49+03mJF/HOn4CZIvWMI//vFy0G0NFUIi2H0+H/fe\ney+XXnoGQVm+AAAgAElEQVQp5557blDfGSpJIEm33U3l6r9R8dHH1Gzbwagf/wRdQiIwuMkqNVXy\nc90eb7d96Ek/DUYNol+iuKgW19frqHn3bVCrMU+bjqesjJI3/43D4SFu2aV9/g196SfIafQ2q5vU\n9KiQ/x10ejX1tQ6qq61UvfUGDXv2ohuVhikhDmt+AUX/fANfXDLmyVNC+tzeUFoip5urNEJgHEIx\nN33N1SGrKps4WVpD2XPP4Dx2FOOEiYgOBw27dtOwZy8Zj/2Gx//2V4qLi1m27BJmzz6NM86YT0ND\nE7fddme7Urv33HMbd999H+PHT2DJkrO46qpr2bbtW+6++8fY7fY2ZXg9Hl+733FqGV673Ul9vYP6\nCh8V1cf42aOrEVRSuzK8F198Cdu3b2XlyivZunUz8+efGVQZ3sYGG3GWCZQUTeTdF/8fVSdP8ugX\nnxEdHcOq85exaNHZg16GVyHYxTwkgv2RRx5hzJgxXH/99aFobkAxZmeT8cvfUPOfd2hY+wU1775N\n6h13D3a3sFvdGIyhdZZBS9hgQ1k19o8+QB1hIePXv0MTFYW3tpbS/3uC2g/eR9DpiD3/wpA+u6co\ntt9Q25ahJUnJumsnDWu/QJeSyuhHHiUpLZ6SbXspfuL3VP79RTIe+y2a6OiQP78nKONgajZBbF5/\ngsK8GsQ+HhaimJSP7KvEsXsHWXlHiZg9h5RbbkdQq7Ht3kX5qj9R9cY/uf32uykszGf16jcAWUB2\nVGp36tRpbZ7hdDrJyRnDD394Gx6Ph6uvXtGuDO+pdFaGt7qqhgN5a3nxpb+RkBTdrgyvTqdn1aqX\nALnMMARXhvfEkZM89PC9HDt8mNOLi3lPq+WPj/6G1IVnN4dVDn4Z3p7SZxv7zp07+eijj/j2229Z\nvnw5K1asYNOmTaHo24Ch0ulIuOp76DOzsO3cgfuUQkEDTX8kJykoNvvyz9Yjud3EX34lmqgoALRx\ncaQ9+DDqqGhqP3i/nZ11oOmPUEeFCIset8tH+auvIOh0pNx+F6pmM5whK5uEK67Cb7VS8dILSOLg\nnrak2MAVm3ioUELBRb+Ir64Oc+40Um6+DUEtKxMRM2ZinjET57Gj2Hbvavd9pdSuIAiBUrunotFo\nWLhwMQBFRYXtyvB2xJ49uwPXWpfhPX7iCI22kzz007u48cbv8emnH3Py5MnA95SCXa1pXYbX7/ez\nZcvXnHmmXE543brPuOmm7/PL395Do/Ukx3ftQLTZEIwmLDNntYqVb1+G9/jxvEAZ3m3btgbK8N50\n07UUFxdRWjq4MqTPGvusWbM4fPhwKPoyqAiCQPzyFZQ9+wy1H35A6l33DFpfPG4fPm9ok5MUApl2\nReUk5owhct78Nte1cfHELDmfmnf+TeNXm4i98KKQ9yFY+iPrVEFxwDk9kPW9a9GPGtXmevQ55+E4\nchj7nt3Y9+4hYkZo6tT0BsWpp8yHeYtzuPSq6SExT73+/Ld4GxsZW7uDhB8/jqBpKxISr7qGwoMH\nqPv4w3YLXDCldnU6Xa+SiToqw+tyekhLnsA//vFih98xGjs+OKW7MrySX8Odt/4Ya1UtKrMZVSft\nwOCV4e0p4czTVpgmT8WQnYNt905cxUWD1g8lWsPcLwJN1vpcmggSr/0BQgcVE6POPAtBr6dh/dpB\njRCx2xRNNfTjYDLJWqkvbhSR889sd10QBOIvXQlA49eDuwPtL40dwKgRcUtajJNz0aWktruujU8g\n9qKlaB0ObLW1PW6/dVZrR2V4O6KzMrwW4yiq6gq6LMPbEd2V4XW6rZRXH8EraIg9/0LM5oghV4a3\np4QFeysEQSDuUnnVrf1wzaD1w94PMewK2qZqAPyJ6RhGZ3R4j9pkJmr+Anz1dR1uwQeK/opbBlDX\nyMJFGJ/b4eIGoE9PR5+ZhX3/PnwNDSHvQ7DYbW40GlW/JBFprDVIggrDWZ0HPcScfwGRlkhydDqu\nu+5q/vrXP7W7p7WG3dn/63Q6Hnro5zz44I+4665bSOlgIQG5DK/D4eCGG77Hm2++zqRJU/B6fKgF\nI8vOv5lf/eoRrr/+Gm677SaKAwpY57sCpQzv1q1bmDdPXsRbl+F96onfkBSdgU+tJ3rxOYEyvD/6\n0R3t2u6sDO95553P7bffyPXXX82jjz6M0+notD8DQbhs7ylIkkTJE7/HdeI4s/72PFbVwJ+LeWhv\nORs/OcbZF09gwtTkbu/vSYTEybfe5D8FSSTGaLns9vaaqoLnZCWFP/8phpwxjP7ZL4Lue6j6CfDZ\n+wfJP1rNdXefEXKtffsfVrFDmMycuUnMXjyx0z42bFhP1euvEb/ycmIvWhrSPgTLK3/+Bp2ubZnh\nkETFNNTz6ZNvURI1kcuun0liSuchpZWvrKbp602kPfAwpgkTO73vVEIVWVZfY+etv29n4rQUFl04\nvs/ttWl7/Vo+3tSI0xTLzQ8uGtJnFYTL9vYSQRCInLcAgLqt2walD/Z+qBMDIIki9p3b0IsunGLX\n2p8uKRlz7jRcJ47jzD8R0n4Ei8Mun5xkNIXWBOEuL0dVIv8mp6/rqCPL3NMRtFoav/lqUIpl+f0i\nTrs3kDUcShq+XI/eKzvIFbNXZ0SedjoA1m1bQ96PYLDb+m/3Ztu1E53fgU8Uhk3dnO4IC/YOiJg+\nAwSB2i3fDsrzbf1kgnDmHcNXX49Rr8Jh93QrqKIXy9tza6tjwAYSu9WDyaxDpQqtBtX09Sb0Pnvg\nGV2hNpmImDUb78mTOPOOhbQfweC094+fQZIkmrZuwaCSfSjdFYYzjp+AOioK687tg+J3USKkQlkA\nDMBvteI8dpSICNkRPFzKK3RHWLB3gCYqCuOYsTQdPoKvqWnAn99iYw/tJLZukxeqiPhI/H4Jj7vr\nF9Q0YSIqgwH7/n0Drq1KkoTD7gm5pir5fDRt+Qa9SYtaLQRV4TFqwVkANH61MaR9CYYWB3Jox8FT\nUY6vpoao0SmAnOHbFYJKhWX2XES7HfuhgyHtSzD0lyPdtncPiCIxaXJSYncL/XAhLNg7IWLGTJAk\n7Ht2D/izbVY3Or0arS50zjLJ58O6cwfqqCgik2SnT3eTWNBoME2egre6Gm9l+xjl/sTjluvRm0L8\nIjvzjuG3Womae7qcpBSEhmYcPwFNfDz23bsGXFvtLweyfd9eAGInjmnznK6wzJVt/IqCMJD0V0CB\nbfdOABInZAItoaXDnbBg74SIGbOAlj/8QOKweUKumdgPHUS02bDMnoup+eVw2LufxObmTEJbsyAY\nKPorxM9+8IDcbm4uZoseh82Dv5sMTkEQME/JRXS5cBUMbME4RZMO9c7Fvm8vCALxM+SSCcEscIbs\nHLTxCdh270Z0D6wA7I8FTnS5cBw8gG5UGjHpSfJzutm5DBfCgr0TtAkJmLMycRw+hN85cLWa/c1H\ntoX6Rbbtkhcoy9zTWqr6BTGJzVOnyvfu3xfS/nSHsuiEWmN3HDyAoNFgHDs+ICQUO3ZXmCdPBhhw\nM0TAaRjCcfA77DiP52HIysIQG41OrwlqLgiCgGXuaUhu14BXArXb3KjVAnpD6Hax9gP7kXw+ImbM\nDJykFLaxfweIPf00JJ8P+/6B01Zb6oKEVrA7jxxGZTJhyMoOtN2dXRVAExWNPjNLNmEM4ALXHxq7\nr6kJd0kxxrHjUOn1LQePBGGGMI6fCCoVjmaNf6DoD03VcfAgiGJgN2a26II+VcvUXBTNcWSABbvV\ng9kSuiPxoGU3HjFzVuDIwWDeieFAWLB3QdzpcwGwD2CSjqNZezSZQ/cie2uq8dZUYxw3HkGlajk5\nJ0jtxDw1F/x+HIcGTqgFxiGEgt1xWNa2TZNk4WQyB7/AqU0mDNk5uAry8XeSldgf2PvBFKPY1825\nzYI9Qq6b4/N2H+pnyM5B0GpxHDkSsv50h9/ffCReCHctkt+Pfd9eNHFx6NNHN5+jGjbFfCcwZWSg\njo7GcfTIgEWFOPohCkJ5CZXEkp4INICIZgFg3zdw5pieHKoQLIq2bWo2q/Rk5wLIJXwlaUC1VbtN\nPrZOG6Iqn5IoYj+wD3VUNPrmzOOWk5S6HweVVotxzFg8pSX4rAMTMRYI+QzhrsVdXITodGKeMhVB\nEFCpBExmXdh5+l1AEARM48bjb2rC26qKXH/SH84yx1G5SJsi2I1mbY+0E31GJmpLJPb9ewes0mGo\nw/wkScJ+8CBqiwV9Wnpz280CLQgnMoBpkrwgOAbQzu6whfbwZldhAX6rFfPU3IBZo+UkpeDGwdg8\nj5xHj4asX13RktcRwnfimNx347gJgc9MEXrstu7zO4YDYcHeDcaxcvqy89jATGJFyChadV+RJAnn\nkcOoLRZ0qXIFQ5VKhdEUvHYiqFSYp0zF39SEp6wsJP3qDiXr1BCirFNPeRn+xgZMkyYHasP0VGM3\nZGahMhqxHzwwIC+/z+vH7fKFdtfSvCgpTnHo+dmnioLgODIwVV0d/RDDrrzPxrHjAp+ZI3T4fWK3\n+R3DgbBg7wbjOFmwO/IGRrAHJnGItp3eqpNytun4CW2KXZkidEFlnyooL4DzeF5I+tUdoc46DZhh\nJrWciNRTk5SgVmOaOAlfTQ3eATiwvT+Sk5S/n6KwyO03C/Ygk3MMGZkIegPOgRLsIfa3SKKIM+8Y\n2oQEtLGxgc+VMOCRkKQUFuzdoEtJQRURMWAae8AUEyKNXdGqTi3cZI7Q4fOKeNzB1cYwjh0LDIxg\n74+sUyVMUQlbBNDpNWi0qh5FQgSiQgbAkRyIkArRIi+JIq4Tx9EmJaGJbCn4pZg4gtXY5XDRcXgq\nK/A19H952lC/E56yMkSHo83iBq1MUiPAzh4W7N0gqFQYx47DV1uLt7am35/nsHnQaENXotXZiWBX\n4sODSVIC0CY3L3An+l+whzrrVBJFXMfz0CYno4mOaXPNZNYFbWMHMI1vti/n9f84hNqR7ikvQ3Q6\nMeaMbfO5orH3xHFomiDbph1H+z865tSjAfuKsvtWduMKph7kdwx1woI9CEwBO3v/F4Gy290hsyVK\nkoTjyBHU0dFok9qW/+2xGUIQMOaMwVdT0+9aWqhj2D3lZYguF8bsMe2umSL0uBxeRDE4k5Q2KUle\n4PKPh6RvXRHqyKCAGWZMW8FuNMsFsHq0c5kwSf7OAJye5rSHVmMP2NfHnaqx93yBG6qEBXsQKBPA\n2c92dlFsLtEaqi1nRTl+axOm8RPbJXa0bL+Df5kVgdDf5phQZ50qZYcNOe0FuzlChySB09GDBS47\nR17gGvv38I1Ql6pV/m6GUwS77EzXYg8iA1dBP3o0KpMJ59H+F+x2m6f5oJG+h3xKkoTz2FFZ2UlI\naHOtJToorLED8MgjjzBv3jyWLVsWiuaGHPr0dFQGQyBEqr9w2r1A6JxErhOyVqnYx1ujJED1RDsZ\nKMEeao3ddUIW7MacnHbXerpzATlJB8DVz3XqQ+08dR0/jspsRpfc/vAWU4QuqNIKCoJKhTFnDN7q\n6n6vgKr4W0KRdeo9eRJ/UxOmcePbtWfqYeLeUCYkgn3lypW8/PLLoWhqSCKo1RjGjMVbWYmvsbHf\nnhNq779SsEoRRK3paagfgD4zE0GjwXm8f80QIR+H/BOoDIZAuGdrejMOxmbN33mifwW70idjCHZw\nvoYGOfs4Z0yHRwGazDo8bj/eILJPFQxZ2QD9WhhNFCWcdk/ozTCnOE4BjCYtKpUwIsoKhESwz549\nm8jIzo/VGgmYAuaY/rOzh7rgkzM/H0GnQz8qrd21nhQCU1BpdegzMuWsvX6s7hdK27LfbsdTUY4h\nK7tjgdbDJCUAQ1YWCEJgR9RfOOweDEYtanXfX9PO7OsKyjj0RGs3ZCuCvf8WOJfTiySFbpFX3l/j\nuHHtrgmCIIcBjwCNPfSn445QAtvvgnwss+cAYPPa2V65G5WgIjMyndSIFLSq4IZUkiRKqmwcKqzH\n6/Oj16pxVcs1SEKhnYhuN56yUoxjxiKo29smjQETRM8msXHMGFwnjuMqyA9E2tS7GjhQewS7186M\nxFySTAndtNKCy+OjqNJKQYUVm9NLXJSB6pPyGZmhMEEoQqejXUvrZ/RES1MZjOhSR+EqKkTy+RA0\nGlw+F2W2SmpddVg9NqbGTySxB+NQ3eCk+KSVmkYXjXYPybEmbFY3lsj+ta8rKHPObvMQGR3cOb+G\nTEWwFwQ+a/JYOVhzBLVKjU6tY5pxLALB/4ayGjvHShpweXx4vCKG5jyLUNVOchacQGU0ouvkIG1T\nhI6aShuSJA3ps0+7Y9AEe7CHsg42Sj995qmUCgL+smI0ESJrDn/G+vxvcPtbBIJereOHs65mUdYZ\nnbbncvt4d30ea7cXU9voanMtFRiFig0HK7GkRzNtbPCC4dTxbDxYDJJEzKTxnY61KUKH2+Xr0d9C\nNWsa9Z99iqqiGPuUFJ7f/k8K6ksC1z/K/4zxcdlcPuVipiVP6rSftY1O3vz8KGu3FeM/JSJlAgIR\nCHyxr4JLzxpDQkzvDxR3VpYCkDRzKrEd/E7RKz9b8kuBvgUzHo1TJnLys1JM9joKLV6e3foyTW5b\n4PoHJ/7HOTkLuHzyxUQbOt7NSpLEoYI63t9wnG2HKmmdKyYAs1FRWu9k8+EqLpqXiVbTdoHuyd+t\nvCgfQaMhbfZU1Pr2QjIxSW5Lq1YF326ChbKUZNyFBcTGGvmycAtv7H0fu7elCqjmoIZrc5dz4biz\nUQkd7zx8fpEvthbxxbZi8kraOqQjgfGo2Ha8moRJSSyYntprgeuz2zlWWUlU7lQSk6La/5wECzGx\nZqrKrUSY9CEvGT2QDJpgD8XJ5f3NqSes65JTsB4/zs8+e4I6dwMx+miWZi3BrDVT2FTCjpO7+eu2\n1yisKueirPPaTEBJktiTV8O/1h6jtsmN2aDh9MlJ5GbHYTHpcHv9HPy2GFu5lX2FdWx9YTNn5qZw\n9TljMXYT097RSfB1u+WEHCk5vdOxNpq0NDW4evS38CXIduqSndt5Ub0Zl8/FxNhxTImbiElrZGvF\nTo7WHucPm1Zx69TrmBrfItwTEixUVDby4TcFfLatBK9PJCnWxPQxcWSlRBJl1lHX5GbfF3l4PH4+\n2JTPf78uYMVZ2Vxw2mhUvXiha/fLBbs8cakd/k63V3ZY19bYqa62djiWHZI6GoAvv3iP12MLUQkq\nFqbNJ9mUiFqlYm3RRj4/vomvCrfx4xm3k2ZpqyHaXV5Wf3yY3XlybkRWSiRzJiQSH2Ug0qyjsKSB\nE5sK8UgSf//gAO9/mcfV54xj1viEwFgG+3cT3W5s+QUYMjKpa/IA7XcnIvKqUlHeSHxK8AuGdnQW\nrq1b+MM7v2efUIlBrWdZ9gVEaE04vE6+LPuKV/e8y9aivVw/+WoidW3bLqux8/f/HqKo0oogQG5O\nHLPGJWAx6dBpVRzeW0HV4WqqrW6een0H//06hmvPG0dKnDnoPiooNeRVqe3fCWU8NVp58SkuriMu\nIaLHz+hvgl10QybYR0LhnO7QjE7HU1GOVFXDBdPO56LMc1GrZC3qtJRZLEybx1/3ruZ/hWupdzdy\n7YTLEQQBUZR4Y+0xvtxVhlolcPEZGSw9IxO9rq0GdnJfJTbgnqum8eaXJ/hqXwWHCuu5fflkclLb\naxhdoURsKHbQjjBF6KmtsuP1+II+hk9jiUSKjcZZkI97ZgLXTbqKuckzA9fnJs8krz6fv+59mb/v\n/ye35t7A5DjZP9Fk9/D/3t7L4aJ6Yix6li/IYt7UZNStbN+SJLH/02MkJ0bww9mjeG/jCd7dcIJj\nJQ3cvHQSEUZt0GMgiSKu/BNok5JQR3T8khqMvXOYGZtNO1VH9hC5KI2bp/6A7KjMwPXTk2ezsWwz\n7+V9xPP7/sFDs+8hSi9r7gUVTTy/5gA1jS7GpUez8qxsxqZFtVEEIlUCJyhk/oxR5Khh3c4yVr2/\nn0vmZ3LJgqwe9dVdUgx+f9dzQTHN9cDGDqDPysK6dQuewgJy58zmqvHLida3zNWLpy7iua//wcHa\nI/xt32vcN/P2wDuzbmcp/15/HJ9fZP6UZFYuzCHmlNBOZ7mVqsPVfO+C8aw7Ws3+/FoeW72dW5dN\nYvaExB711VUom4wMWZ2PnzIOTrsHgt8wDzlC4jz9yU9+wtVXX01BQQGLFi3ivffeC0WzQwqv38s2\nXSUAC8VMlmYtCUxQhWRzIg/MvovRllFsqdjOloodeLx+Vr2/ny93lZGWEMGvb5rLZQtz2gl1kF8q\ntVpgfGYsj14/m6XzMqizunj6zT0cKqzrUX9dBfmoLZFoYuM6vcds7rkDtdZZzwmLG6Nb5Oa0S9oI\ndYWxMdncnnsjgiDw0v5XKWgspqzGzk+e28jhonpmjI3ndzefxpnTUtsIdWjJOjVb9MyfmsKvbpzL\n5KxY9p2o5TevbKemIfjDPjyVFXKmZQeJSQqCIGDsRbnWfK0Nl05gVK3Iw3N+3EaoA6hVahann8ml\n2RfS4G7kxX2v4vF72Hm0isf/uZPaRheXzM/koWtmMC49up15QVlooqMMXLV4LI/dMJuEaAMfflPI\nX98/gKsHhapcRYUAGDK6EGi98DUA7NLLO46JNjO3TP1BG6EOEG2I5I7cG5mdNJ2CpiI+yP8ESZJ4\nb+MJ3vjiGEa9mrtXTuWHSye1E+rQ4sxNSbLw4ytyuXP5FNRqgefXHGDtjpJ293dFQLBndqXs9G4c\nhhohEex//OMf+frrrzlw4AAbNmzgsssuC0WzQ4qPC77ggFGO1811xXRq54vUWbhl6nUY1HrezfuQ\nJ975ht15NUzMiOGn184kNb7zLaTdJod1CYKARq1i5Vk53L1iKn5R5Nl39rEnL7iSBr6GBnx1dRiy\ns7u0R5osPZvEkiTx5tH3qIyRp022tXPn5vjYMdw85Qd4RR+vHnybJ/+1g8paB8vmZXLXyqmdmpdO\nTaOPNOu478ppLJ2XSU2ji6fe3E1NY3DCPbBr6SB+vTXmCB32HhREs3nsvHbk31TGa7FYvZjdnX/v\nvIxFnJ48myJrCX/Z9i9e+OAgGo2K+66axvIzszstcnZqyOeohAgevX4OE0ZHs+tYNb//xza8vuBC\nE92FhYBcfrkzeqOx76s+yIeu3fhVkNOk79SGLggC14xfSaIpnnXFm1i1di0fbykiMcbIo9fPZua4\nzlXj1rH8giAwe0IiP/3eTCLNOv61No/3NgYfkeMqKGhWdmI7vUcJKuhJstZQJJx5GgQn7VWsL/kK\nX1IcqFS4iwq6vD/WEMOKMctw+92UGzczd1Ii9105DVMX5zVKUnO87ikOmxnjEvjRFdNQqWDV+/vZ\nd6K22/4G4tezOtdMAMzmniVkbKvcxeG6YwGNx11U1OX9U+InMit+FtWuKlxRedxxWS4rzsru0lau\nCJbWsdsqQWDlWdmsODNLFu7/Ck64K9Ea3Y2DyaxD9Eu4Xd1rwZIk8caRd2n0WIkeI0cFKZpgRwiC\nwDUTVhKvTeaE8xCaqHruv3IaU7I630lBx4WvIoxa7r9qOtPHxLMnr5oXPjiIr5uDuEHW2AW9ocPE\nJIWeFkRz+py8ceRdVFodmrQ0vKWliN7Ov2vQGPjh5O+jktQckjaQkqziZ9fOJD6qa8d4R+WbM5It\n/PwHs0iKMfLxliI+3VrcbX99TU346moxZGV1qewoCoUzrLGPbCRJ4p28D/FLflZOvBR9Wjru4mIk\nX+dCQJQkDuw04a9PQB1Vx4QZTWi6iUV2OeV6JR3F607OjOX+K6ejUgk8/8EBik927TQLVrD3ZNvZ\n5LHybt6H6NU6zjvjGvk5XQg0gEa7h6Nbk5G8OvTp+czO7d7x4+iiLsiy+Vksbxbu/+/tvTi6EcTu\n4iJQqzuM429NT8Zhb81B9tUcZGx0NhNzF7Y8pwsKym1U7pPNIElTCsgZ1X3OR2dJWhq1ijuWT2ba\n2Hh259Xwj/8d7nKnIbrdchz/6NEdxvG3xmTWBa2xf160AZvXzgWZ5xA1Zjz4/biLuxawhw77cBWN\nQ9B4mTCnhqggok4cNg9GU/vyzfHRRh64egbRETre/vI4Ww5UdtmOMle72rVA730NQ42wYO+GvTUH\nOVx3jImx45iWMAVDZhaSz4e7vPMDJ9798gTbDlUxyn0GBrWeTwq/aBMW2RHdnZw0Lj2aW5ZOwuPx\n8+w7e6lrcnV4H7QW7F072QICLYhJvOb4/3D4nFyacxEJcaloE5PkOO5OhIrXJ7Lq/f1U1/qZrJ+P\niI+Xd/272+d0V6L1kvlZLJmTTkWtg+c/OIC/kxOdJJ8Pd0kx+lFpCJquHcPBVroUJZGP8j9DQDYt\nKDZrxYbdETUNTv7yn/34bdGMi5hMtfsk31bs6PI50PU4aDVqfn7jaeSkRrLl4En+u7nz57uL5bBX\nfWb3DldThB6n3dNtQbQ6Vz1flnxFtD6KxekLMGQpOR6dL/Q7j1bx7/XHMTtyiNPHsa1qBycd1V0+\np7vyzXFRBu6/ajomvYbV/zvMwS78UO4gHKcARlNYsI94vH4v7+V9hFpQc8XYSxAEAUPzC9LZJN5y\nsJJPtxWTEmfivhWncXb6mdi8djaVbu7yWQFbYhfJSbMnJHLl4jE02Dw89+4+3B2kf0uShKuwAG1S\nMmpT1yFhwdZJqXLUsK1yF6nmZM4cdToAhsxMRLsdX017u78kSbzxxVGOlzYyd2Iid5x1PuNixrC7\n4gDHG7rW8oMpJ3Dl2WOYlhPHwYI63lrbcfanp6ICyefDkJnZ5fOgbXJOV+w4uYdK+0lOS5lFkjkR\nTXQ06sjITk1STreP597bh9Xh5drzxnL9tOXoVFo+PPEpTl/nCzO0ONI7K99s1Gu457Jc4iL1vP9V\nAbuPdSwkXUWKwzCzy+eBPA6SJO8eu+Kj/M/wij4uyb4AnVrXUlqgsOPSAkWVVv720SF0WjX3XT6D\n5WMvDCySXeH1+PF5xS7nQlpCBPdenosgwAtrDlDdiXM9GMcpgFqjQm/QhJ2nI5mvirZT56rnrLQz\nSN51YmQAACAASURBVDLLoVXKit/RJC6qtPLqJ0cw6tXcc1kuEUYti9PPxKgxsLZ4Iy5f5xqhI8ia\n00vmpLNoeiolVTZe+7T9IdvemmpEpxNDN1tOCH7b+VnReiQkLsg8J+AgU7a0rg78Det3lbFpbwUZ\nSRZuvGgiKpWKZdnny20Vru/yWcEcqqBSCdx6yWRGJZhZt6uUTXvL292jaNHKgc1dEczOxS/6+Tj/\nc9SCmosyzwVk+7l+dCa+ulr81rbmMUmS+Pt/D1FWbeecWWmcPTONaH0USzIWY/XaWF+8qcs+OZr9\nLV3ZgyPNOu65LBedVsXf/nuI0mpbu3sCAi2I+dCShdv5PC2xlrG9cjdpEanMSZ4BgDYxEUFv6NAU\nY3V4WPX+frw+kdsumUxGsoUZCVPJsKSzu2ofRU2dR7Z0ZZZrzbj0aK49bxx2l49V/9nfTuGRJAlX\nQQGa2Lg2B4x0hnK62HAmLNg7QZREPjwiv8jnjl4Y+FyXOgpBpwts7RRsTi+r3t+Pxydy89JJJMea\nADBpjUFp7cEWvhIEgWvOHUd28zb8y91tTUKK9qgfPbrb36jRqtHp1V1O4hpnHdsqd5FkSmRGYss5\nmYqgcDVHXCjklTbw5to8Ik1a7rlsKnqtHNaZHZXB5MRxHKo7SnFTaafPC/ZlNuo1/OiyXMwGDa9/\nfoyiyraC1V0s90s/OrPLdiC4sgJbKrZT46pjfuppxBlboioMGfLC4TrFzv7p1uJANNTV57SEWy4e\nfSZmjYmNZZvxdGKeCzjSgygtMTrJwg8vnoTb42fVf/bjPCUM0l1UhMpgQJuY1G1bxiAW+k8L5UV+\n+ZiLAou8oFJhGD0aT0V5mxpCoiTx9Bs7qWkO7Zw+Nl6+XxC4NOdCAD488Wmnz+rJwe4Lp49i4fRU\niqtsvHqKwuOrq8NvberWDKNgMssZ2X7fwBzc3h+EBXsn7Ks5RLn1JHOTZ7aJzRXUavTpo3GXlSF6\n5IknShIvfXQoMIFnnFIKYHH6AowaY7PW3vEWvCfHf2k1Ku5cPoUIo5Y31+ZxpJVtUXHkBaOhKc/r\n6kX+vOhLREnkgszFbcLZFE3Y3cq+3OTw8MIHB5GQuGP5FGIjDW3aWjHxAgA+K/qy0+c57B50ejUa\nbfe1t+Ojjdy8dBI+v8hf1+zH4WoxIbiKikClQp/eteMUujdJ+UU/nxauR6vSckHm4jbXlJ1L63E4\nWlzPuxtPEB2h47ZLJreJ1derdZyZdgZ2r6NTW3tXjvSOmDMhkQtPG83Jeif/+KRFqIkuJ57KCvSj\nM7p1nEL3C1yNs5a91QcYbRnFhJi2NWf0ozNAknCXtSzaH35dwK4jVUzJjm2XVDU+dgzjonM4Up9H\nqbX9jgtaFcULsk7M984dR05qJN8ePMmGVgpPixkmSMHeA9/TUCUs2DtAkiQ+L/oSAaGNtq5gyMgA\nUcRdKk/iT74tYn9+LZOz2k9gAKPGyDnpZ2L3Odhcsb3DZ/a0VG1spIHbL52MKEk8+c8d2Jrtoorm\nqE/vXmMHWai5HF78HYTN1bsa+LZiB4nGeGYlTmtzTW0yoU1KDjhQlcWt3upm5VnZjB8d0669qUkT\nyLCks7f6AJX2kx32x9HDEq3TxsSzdF4G1Q0u/v5fOUJEEkXcJcXoUkeh0nbfVncF0fbWHKTe3cAZ\nKbMD2aMKp2rsjTY3z39wEAGB2y+dQmQHv2Vh2jw0Kg3rSr5ClNqPe7C7ltasaM5e3XGkivW7ypr7\nJDtOgxVoxm58DV+WfI2ExOL0s9qZiJQdorJjPFhQx0ffFJIYY+TWZZM7DHFdPPpMud3Srzt8Xkeh\nr12h1ai4Q1F41uUFdnGKshOMWQ5GRmRMWLB3QF7DCYqaSpgzahrJ5vZpywFttaSIYyUNvL+pgBiL\nnluWTeo0RvvMUWegUWnYVLq5y5fZaAo+ZX5SZizLF2RR0+Dk7/89hF8UcRcVoYmL6zSF/lSUhcTl\naO8w21S2Bb/k57yMRe2ybEHeFYgOB97qaj7eXMjBgjpyc+K48PSOXyBBEDg/82wkJD4v2tDuut8v\n4nL0/ASp5QuymZgRw57jNXy2rQRPZQWSxxP0rkWtVmHo4gShDSXfALAwbX67a5rYOFQREbiLChFF\niRc/PEiT3cPli3IYlx7dYXuROgunJc9s1oAPtrvem8ObNWoVt186BYtJy1vr8sgvbwoqMak1gRju\nDsbB4ZWVkmh9FDMTc9tdNzSbvNwlRdRb3fzto4OoVAIPXzen0zIQk+MmkGiMZ0flbqye9v6B3pz5\nGhtpaN7FSc27OF/LLjZowa4cQhMW7COKdc2OrUsnLunwuiLYrScKeOED+bT62y6ZTKSp8wkYoTMz\nO3E61c5aDte1P4HIYfc0F/rv2Z/k4jMymT4ugX0nalm34SB+a1PQmgl0blf1ij42l2/DrDExO2lG\nh99VIi3yt+9nzdcFxEbquXlp54sbwNT4SSQa49lZtReb197mmtOhnCDVs6p6ijM1yqzjvY0nKN4j\nH9emzwh+HMzmjk8QKrGWcaKxgImx4zpc5AVBwDA6A+//Z++9oyS560PfT3WOk3ty3JyjNiqsJAQS\nCiRjHgbDRRhjHDg8Xb/jc1+wr6/TxX6PCxiuMRgso4vBZIQQKGu1knalzTnvTs6xezqHqvdHdfX0\nzHRPV3XXzG6P+nMO54jpqq7f/vpX39/3942jo/zqlYtc7pli++oaHtzdsuDz3tVyDwICL/W8Ns8B\nnm+jkUq3lc8+thFRlPjGL87jV8JeVUTEwMLRQW8OHCWaiHJv850ZN3lLQ4Ncvri7i2/98gLTwRgf\nuX8VazKc3BQMgoF7W+4iLiV4vf/IvM+12NjT2bKymkf2yae4J399iXBvD6bKKoxudQW0SqaYZch4\naIIL41foKGtldXXmI6y1sQmMRgbOX2HKH+VDB1Zk1c7SOdCyH4BDfW/O+yzfLjEGg8Cffmwn5S4L\nxw+eBtRrJpDdvnxq5Cz+WIC9jXdgMWbWuJQN5PQbZzAIAn/4/k05i3QZBAN3Ne0lLsbn2ZgLaVpc\n7rTw2ffJpqmzb+QxD65kB6HobOfjweRvdW8GbV1BmYdTh85QU27j04/M7zE7lzpnLZtq1tPl66Fr\nTmRIPqYYhY0dVTx2ZzvjvjCjF6/JjlOPumJZNocFQZgv0BJigoN9b2I1WrizcU/GewWTCUtzC6He\nPq71TLBzjYcHdub2b+yp34ndZONQ/xFi4uy5L2QePnB3B2tbKrh0sYfE1JSqYAIFRw7TXDFQEuxz\nODxwFAmJu5Lx2pkQTCZC5R5c02NsX1HFQ3vULZpWdzMdZW1cGL/CaHCmNEAsliAaSeTdJabCbeVz\n79tIXVj+zkRt5iYCmZhJzpn9Mh/qO4KAwN2N2WvLm5plrbQiMMZv37eKlU3qKlDubbgDs8HE6/1v\nzTJL5auhKaxvq+T9d3VQ4RtBQsDctLDWnI5ycvGnNTKejvo5PnyaWnsNG6rnt1JTiNfKpYwbouP8\n4Qc24bSpM6cdaJI3+jcH3p7190Ln4X13drCp2YUjMEmgok6V4xRkJcHumF8Q7ezYRaYiXvY27MJh\nzl4CIFBei0FMsMYa5vGH16mqm24zWdnfuJvpqJ+Tw2dmfabFkT4Xo8HAH7x/Ix2CbGcPVOSOClIo\n2diXGQkxwZuDR7Gb7OyY4yxM5/zNca7HnJilBJ/YVampTviB5v1ISBzqnwl9DGl0EmVibWslO8rk\nF/IHF0JZMzLnkmkR90730+nrZn31GjyO7DVNfnFsiCmTi6b4FA/snN9PNBtOs4OdtdsYC41zZWIm\nwUirsywTj+xppSE2yZiljGeOZY62yIQjJdhnopYODxwlLsY50Hxn1gJXsbjIDy7K9+yqiNHRoL5F\n5NqqVVTbqjgxfJpQfCaxphBNFWQB/cntZRiQuBiyc6l7UvW9mWK4lY3nrizaOsDIZJBDI7IA/vB6\nKw6VmxvIG5yAwBsDb836e9BfWK/TCpeVRzrkMT3fk8AXVCeoS6aYZcaZsQtMR/3srd+Z1fwwPBHk\nn5++wIhdFniG4eylBTKxvXYzZRa3XNI3IduU83GWZaJ8eoSIxcGZ4Rg/Oaiu6l0mU8yhPtneqWiU\nmXjr4hDPvd2D112DNRpE1Nip/u5m+USUblstVKABJMZGMSViTLk8PHO4S3VFzJR9OdlvVZREDg8c\nxWIws6dhfmlihe+/dJXzkxA3WamYHtE0VoNg4M7G3UTFGMeGTqX+rkcTa9PYIAAjtiq59rvKcscO\np4V4TCSajIcfD01weeIaK8rbaHRlLiIWiSX4p5+fp9con9hck5kjnrJRba9iXdVqbnq7GUxGSyUS\nIuFQrOAuRmU+OSP3pljGN35+XlXRNKvNJNfoLwn25cGb/UnNpCmzZhIMx/jqT84SjMTZcfc2gJyF\nj+ZiMpjY23AHoXiI06Pn5O/N01mWTsLvJz4+TvmqFdRVO3n+aC+vn82tsc7VTkLxMMeHT1Ftq8xq\nfugemubffn0Zm8XI6js2AvMTdHLR5m6hxd3E2bGLTIbldmip7NsCBFqkV/491u/ZjNlk4F9+dYHB\n8UCOu2bmwZ8U7NcmbzIWnmB77Rbspszmh4On+3nt9ACtdW6cHe3ERoY1N/ne27ALg2DgjYG3U05U\nPZpYK+ty54Ht+EMxvp4hIzMTc9fD4cFjSEhZbeuiJPHtZy7SM+Jn3a4NcvXTXm3vBMD+xt3y8waO\nAoX5W9KJ9HZjcDhZtbGdK71TfO+FqznLM880tS4J9qJnJDjG5clrrKrooN453x4nihL//MsLDE0E\neXB3C3fcK0eKaBVoAPsa5GbYR5LOQz00VeVlcrS384UPyxmZTz13Jecx3GY3z3KYnRw5Q1SMsa9h\nd0bzw4QvzNd/dpZoXOSzj22kZu2qWc9XiyAI3N20FwmJI8nYfj02OGUc9RtW86n3riMUSfA/fniG\nqRyOsJQpxidfd3hQFjCKwJnL2RtjfO/5q7jsZrm+fFurnKDTp635Q7nVzZaajfT7B1NO1KA/e+Er\ntUR65cqW++/blsrI/PavLuYs8JV+gkuICY4MHMNusmUMcQT46cEbnLg6yrrWCn7noY1Y6hsI9/Qg\nqTQFKmyp2YDL7OTtoRPExLgu74QYDhEbGcHa2spnHt1Ia62LQ2cGePlE9sxnBSVxr1g7w5UEexJF\nuNzVON9pKkoS//aby5y/OcHmFdX89r2rMNrtmGvr5BK+Gn/8WkcNqyo6uDp5nbHQuC6mmHBaEkZ9\nlYM/+ZCc/v8/f3aOgbHsGqviMFM0pCMDxxEQ2Nuwc961/lCM//GjM4z7IvzWgRVsW12TSoTKVbo2\nEztrt2IxmHlr8ASiJBIMROXa2xra380lPUFr38Z6Pnh3B+O+MF/58Zl56fbppNvYg7Egp0fPU+fw\nsHJOZySQW9v90y/OYzQKfOHDW/BU2LG2JHMbNJ7gYMZ2/cbAW8RjCaKReEFrQUomz1kbGzGYzXz8\n3WtY21LBiSujPPX8lQXXa7rGfmH8Mt6oj11127EY54/ntdP9/ObtHuqqHPzRBzdjMhqwtrUhRcLE\nRrSZY0wGE3sadhKIBTk7ekGnTb5PTtBqacVqkes3lTkt/ODlaxy9tPD4lBr9UQ2dqm4nSoId2Z76\n9uAJ7CYbWz2bZn0mSRLfe+Eqb5wbpL3ezR+8b2OqNrS1pQUxGCA+kbv5xVz2N8ia4JHB4/os4jkZ\np2tbK/nUe9cRjMT5hx+con8B4a5oJ0OBYTp93ayrWk2lbXb4Zjga58s/OsPAWID37Grh4WQSkqmq\nCoPTSaRXm6YKcvOFHbVbGQ9PcH3qplx72zm/9rYWIr09mKpmErQe3d/OPVsb6Rn28z9/fo5INLM5\nIt0Uc3T4FHExzr6GXfMiO/pH/Xz1x2eIxUU+976NqUggm5J5mYcZQnaiVnJy5CyTPjlRpxDBHh0a\nQopGU2vBZDTw+d/aQmudrLH+7FDmKozpzw0GoryZNItkMsO8fnaAp567gtNm4n//7S2pMFdbARuc\n8k4cHjiqi8Ye7p1dN6m63MYXPrwFq9nIvzxzMWtFTCj+FnklwQ5cmriKN+pjZ922WU5TUZT4wUvX\nOHiqn5Zal1z7Oa0LUioDNQ9tdXvtZmxGK28NHk/VAS/UFCPHLM/UqblzcwMff/cafIEo//D9k/SN\nzM/uA7C7LMSiCd7slU1DiqlIwReM8qUfnqZz0Medm+r5yP2rUgJPEASsLa3ERoZJhNT3I1XY23AH\nMLPBFTIHce8UCa93VsyyIAh84sE1bFtVw8WuSf6/H55KlV9Ix2I1YTAK+H0RDg8cxSAY2DPn1HKj\n38sX//0kvmCM333PWrantXSzNDSC0ZiXYDcIBvbU7ySaiHKm/zKgjzkqPVHNYTPxnz+yLdV16Iev\nXEPMoLkr8z/pnebixBVa3U00u2eHzx483c+Tv76Mw2bi//joduoqHanPlLkP5zEP9c5aVpZ3cHny\nGmNTXnk8eig7afPQ0VDGEx/Zislo4J9+cT6rcz1XeYXbHV0E+6FDh3jooYd48MEH+da3vqXHVy4p\niq17X1LIAATCMf76X9/mpRN9NNY4+dOPbpuXfKMkwITz0E4sRgs767YxFfEy4Z1esPZ2LhKRCNHB\nQawt87vkvGtnM598cC3TwRh///2TnL0xfyErNeBP9VzAaXKwxbMx9dnAWIC/+e5xbvT72Luxjk89\nvG5eeKcyD1GN9mWAVRUdeOzVnB68KNfeLmhzk58/t06O0WDgjz64ib0b67jR7+Pv//0k497ZxdgE\nQcDhtOD1Buj3D7K5ej1llplMxbM3xvh//+MUoUiC33tkPfdtnx3eKZhMWBubiPT1IiXU9SJNZ09y\n7V3ol7OSC5qHLPWCypwW/vSj22iodvD80V65xO2cE4wiSPvGhxElkb1pm3xCFPnF6zd56rkruB1m\n/uxjO2irn53NaU3mNuSzwQHsa5Sf1z0qO/4Lm4ceBLMZS33DrL+vbq7gCx/egtEg8LWfnuVXh7vm\n+R6cRR7yWLBgF0WRv/7rv+Y73/kOv/rVr3j22We5cUN9g9lbTSAW5NzoBeqddbS55UV5Y8DLX/3b\nMY5fGmZjRxX/5eM7MpYLSBU+ykNjB9ifXMTT06FUE+t8CPb0yl1yWjIn5Ny7vYnfe2Q9kViCr/z4\nLN9/8eqsRsj25CKOhBLsqt+O2WAiIYocPNXP3/6vmbKrv//ohlmVChUUAZKPI1kQBFlrj8jfq4um\nmqEAmslo4DOPbuCBO5rpHwvwF//6NgdP9c/SWh1Oi6yhSTMCJhiO893nLvOVH59FkuBPPrSZOzc3\nzPt+5blSLEZ0WJt9GaDGXsWaipWMTckRQos1DzXldv6vT+xkXWsFp66N8bf/6wRXe6dSnyuCdGzK\ni0kwckfdtuT/D/H3/36KX77ZRXWZjT/72A5aaufXIzK6XJiqqvMW7Ns9m7EYLYwm5yHfkE8pHic6\n0I+lqRnBOD/BaV1bJX/2sR1UuK387NBN/vt3j+JNc7Ar85CpzEQxkJ+KmMbZs2dpa2ujqUnWYB55\n5BFefvllVuboDH+7cGz4FHEpwd76ndwY8PGrw12phtEfeWAN79nRlNXmayqvwFhenpd9GeSQv3pH\nHVLEgLUy/58ikOzmtFBFxzs3N9BS6+Kbv7zASyf6OHVtjPt2NHHXlobUIjbFrGyr3s6JKyM8/UYn\nfaMBrBYjv//oBvZtyt4I2VqAfRnktPJXzsjOaz00VVuWeTAIAr/zrtU0e1z88JXrPPX8FY5cGOL+\nHc1sXVWN3WkGUaDcUEGdqY3nj/bw/NEepvxRmj1OHn94/YIJSNbWVjgsR6RYG9Vn/yrsbbiD586f\nBPKfB0mSiPT2YK7xYHQ4Ml7jtMlNsb//4lUOnh7gi/9+kl3rannPrhbaG9wYTQKJMGz2bGRiUuSn\np65w5PwQkViC3etr+eSDaxdMQLK2thI4fYq4dwo86uqzKNhMVnZ4tjB8TijIkR4dHJA7aC1QVmJF\nYxn/9VO7+Oenz/PW+SFOXh7hwLYmHtzdoqo2/e1MwYJ9eHiYhoYZDaauro5z584V+rVLxuHXbuK2\n1fLTX0SIhk4AsKa5nA/cvYK772hldHThxtHWllaC58+R8PtVV1RUEASBXVU76ZQgYgzm/W8I3OyS\nx5KjNkprnZu/+NQufn7oJgdP9/OTgzf4+aGbNDsEagFLqIIvfvs6kgQCcNeWBj50zwoqciSJWOrl\nAlD5OMwAKm0VtFrlscfN+dfnCPf2YLDbMdXUZL1GEATu2drI5hXVfO+FK5y6Nsa1Pi9mk4GV9ghu\nrIhDTfyXf5ZzGkxGgQ/e3cF797blbEg+43PpgT3ZSzFkY1vtZl6NyzZ2m4Yqn+nEp6ZITE9jX71m\nwetMRgOffGgd+zc38IOXrnHs8gjHLo9gtRhZb4hgilk5fzzB4SHZgVpdZuV337OG/Zvqc54srS2y\nYI/09sIq9WUdFPY27OTZ2GWwJPJ2pCvm0VzlqxXz1MkbE/zwxSu8eLyXF4/3Umkzsgq41jfIPopD\nSU2nYMGeb5ynR+NOvliU9zVisVThrK6mbUMZ79ndxuZVM4Ih1zgDa1cRPH8O2/QYFR2Zj+gLsT+4\nk05OMMl43nMyeLMTwWikactaDJbcmt7nP7qDx9+/mVeO9/DayT68kWvgr0OYqmJ9exWbV9Zw59ZG\nOhrV1X4BGGxvI9DVTXWFDYM5u1DK9m9cX76Wq/gYZgCPJ3Ps+EIkwmGuDg9TtnEDtbW50/o9Hjd/\n9bkauod8vHlmgDfODBCMD+CmkcRoLVtX13Dn1ib2b26gXGX2Y9yxnj5AGh7I+7f0mGqJAn7HOOs8\nC6+nTM+Y6L4KQNW61arG4PG42bOliWMXhzhxeYSzN4eJ+idwBMqxhl3csb6Sh/a2cceGeowqhaxh\n01omngHT+FDWcS5Edc0WXox3EbIFcFeYsZltuW+aw/SY/Oy6LesoU/H8h+vKeffuVl4+1suxi8Pc\nnOwkOi4QIXbbyCotFCzY6+vrGRiYyXAcHh6mtjZ3NblcmvBS4XY5KBMcfOKTM45TZWwejzvnOMUa\n+eUbOXeZWEO75uf7RuQ42QlxnDOd17KmbWdDEkUC3d2Y6xsY90YA9RrvvnW17F3r4YsHj8BwHfva\nV/LgYzPt77T8RoaGJqTrNxg4dzWrlrTQfNojZYCPs5PnGRrOXP99IUI3roMkYahv1DRuh1Hg3Tua\n2L3RzZd+fhTGG/n9d+1gzUY5SS0aijIaUn8cN9d4mL5xk5ERX14+E1vcSVgI8XLnm7Q5s5/Ass3l\n+DlZ449X1WmahxV1Lvl/66Z58RdhhEAl//UTu1MmoYnxzBFVmYiVy9FCE5ev0Yz2dz0WjSMkjMRM\nYV64eDjl79DC1JVrIAiEnFVEVDzf43EzNRlk56pqdq6q5t8vXeTwwDH+eNunbxtZBeo3yYKdp5s3\nb6anp4f+/n6i0SjPPvss73rXuwr92iXD6ZKTc/I9eaQch3nalxUbXswS4a2hzK3SFiI2MoIYDmsq\nS5pOr7+fgbhc7yYRzj/LrpAIIYBIQN7gvMIklyauar9/AYehGo4NnyJqliNlCnGYWVtaSUxPk/BO\n5b44A2JIQLLEOTt2nmBMe/io1m5BczkyeCxlDss3httUXYPBbi/4nYibI6mINS2k/Ax1dRhs2rX9\nSCLKyZEzVNrKWVe1OvcNtyEFC3aj0cif//mf8+lPf5pHH32URx55pGgcpyB73RMFZJjJHdqteduX\nlZfHZIWjQydJiNpC5RSBls1hmIu3Bk8gGuIYjIU5itK7SuVD+sv81tAJzfcXItglSeLI4HEkc2zW\nWPIhFcedx3qQJEmO5XdZiIlxToyc1vwdkd4ejC43psrsDS6yMRme4vLENdxuuTZOvvOQym0YHiYR\nztzjdyGUd6LM7eCGt5ORYPZEokzEx8cQQ6G834nTI+cIJyLsadiZtarn7Y4uo77nnnt4/vnneeGF\nF/jsZz+rx1cuGWo61C+EYDBgbW6RO7THtH+H8vKsbmhjOurn4sQVTfcXItBiYpzjQ6dwW1w4XbaC\nsuysTc0gCPlvcIEoJrOB2rIazo1eIBDT5kyO9PSA0Sg3QdFIl6+HocAwq+rlzamgeSigxEIkHEcU\nJWoqKhAQNGuriWSbQmtLa15moLcGTyAhsaJWbpBR8AYnSQS7ta8H5bntHvm31DoPhZ7elAYwe+vv\nyHHl7Utxbkc6okdYk7W1FUSRaL/6+t8KyrH/jhbZtq2kcatFrfc/E+fGLhKIB9lVvx2nq7CiRwab\nDXNdHZFe7bVzYKb29r6GO4hLCY4Pq9dWpUSCSF8v1sYmBJN2t5FSUXBfm1yeVw+NPZ/QT+W55W4H\nG6rX0u3rZcA/pPr+mYxT7WtBlETeGjyGxWBmbf0KoHCTFID/Zqfme5WNdVVdG3aTnbcHj2s6yabe\niTzMUWOhCa5O3ZAT5xboRXC7844X7Hp0S0lpaXmYIZTnrqxrodXdxIXxy0xFvKrvj/T2YPXUaA61\nhJkyxXc27sbutCBJEM6Qbq8WW2sbYihEbEzb0VkUJULBKA6XlV11OzAIBt5MK2Obi+jQEFIslteL\nHI6HOT5yhipbJRtqV2O1mQpaC6bKKowud14ae3oxOKWsw9z2gQtRiGC/PtWZKlNcWe6aNZ58UHwu\ngc4uzfcq819WZmdX3Ta80WlNJ9lCNPa3FW29QbvD9naiJNiz9PzUQiGOQ7n9lwmTycj+xt2pgmRq\niHu9JLxTODsy92ZdiLHQOJcnr7GyvJ16Z50uRY9mzBDa5iEciiFJ8m9RbnWzuWYD/f5BeqZzl1eV\nn9clP19D82qFkyNniSai7G24A4NgwOW2FiTYBUHA2tpKbHSURDB3Hfh00ovBba5Zj9Ps4O2hE6q1\n1ZlSAtrnQaluuq9hly6nWKV2TiAfjT2tAJgSEXNk4Jjq+yM9PRjLyzGVqw/XBbmD2uHBY1iNw/az\nyAAAIABJREFUFrZ7Nue+4TamJNh1qAlhaWzKu8FAeu3tO+q2YTaYOTxwdFYv0GwoJwRnR7vm5x5O\nvihK5T5dTi55OlDnNthQxvRG/1tZ70lH2VBteQi0wwPHEBBSdYJcZTbCwRgJFZ12sjErUUkDMxq7\nFZPBxO76HfhjAc6MXVB1f7inB8FiwVKvLWQ2FA9xauQcHns1qyo6sCeTowra4JK1c4Ld3Zpr56QL\n9hZXE02uBs6NX8IXzR12mPD7iU+M56WtX5y4wlTEy676HdhMhXVuutWUBLsOGrvBYsFS30Ckt1dT\ng4FU+6/kGOReq1sYC09wbTJ7aVWFcHdSsK9coWm8CTHBW4PHsJvsbE82ULiVGvvcssXrq1ZTZavk\n+PBpQvHcURWRnm4QhKy1crIxmFamuMomR5G43PILHQ7mb5KaqSFU2DwovQFe7zuS9R4FMRYjOjiA\ntblZdfNqhbcGTxATY+xv2I0gCBiNBmx287ym1lqxtrUhRqNEhwY13Rf0y450s8WIIAjsb1B/kk1F\nieVhlns9qUjcnaEnQ7FREuw6VXGztrbKDQZG1duXQ0nh4XDOZGoq2qrSwWchlKO3a4U2wX5+/BLe\n6DS767enyhTrobGbysowVlRoPrnMbTSS3gv0+PCphW6diVmu1R6zrPgY0rskKYK9kHlImeY0nlzm\ntoOrd9aypmIlV6duMBRYuLBYdKAfEgnNZhhJkni9/wgmwTgrEShTU2utKPMQ6dY+D+lF8XbXb8ds\nMPN6/5GcJ9l87eujgXEujl+ho6x1XpniYuQdL9hNJiMWa2EOM8jPgTpjgpg59q0ob6PeUcupkXN4\nIwsfPSM93Rhdbiw12rz3mRooOJNp84U2FrC1thGfnCQ+rb65daZGI/uUXqD9CztR42NjiMFgqtGF\nWsLxMEcGj1NucbOlZkPq704dBLu5tg7BastbY7enbfR3N8s1Z17PYZbK13F6ZfI6w8FRttduxW2Z\nccA7nBaikQRxFX1Ss2FtawcgnPSBqCEVy59WBM1hdrC7fjvj4UkujF9e8P5wlpLFuXj55htyb9em\n4tfWoSTYAXRpXGtTFnFXl+p7UppqmkATBIEDzXeSkBK80Z/9CJ4IBOSY5bY2TTHLI8HRlGbS5Jqp\nRTLjMCvw+J2HGSJTa8ByaxmbazbQ5x9I9QLNhCI0tEbEvDV0gnAizN1N+zAZZkIkXW7brDHlg2Aw\nYG1J5jZE1X9PwB9JOdIVttZspMzi5u2hE0QS2b8rX8fpoeQaO9A8u2iZLj6X5hbZ96RBY1cc6XPL\n9R5ovhOAg71vLnh/pLtbDr1VUdZEISEmePnmYewmOzuz9HYtNkqCHXkRh0M6Ocw0LOJsLfH2NOzE\nbrJzqP8IsURmW2+mLjlqeLVX1kzua7l71t9TDrMCN7h8en9mm4d7mmRh82rv61nvjeQRsyxKIq/1\nvYlJMHLXHA3NVVa4xg7JVnnJ3qNqCQWiqYQ5BaNBjpYKxcOcWCC2P9zTI/sZmptVP28yPMXZ0Qu0\nuBppL5ut4ephojRYrdibGjU1t86k7AA0uRpYVSF3VxoKjGS8VwyHiQ4NYm1t0+RnOD16Dm/Yx976\nnRl7uxYjJcGOPkX1jQ4H5to6wt1dquOvszWxthot3NW4B38skDVRJ9zdBYBNQ4ifPxbgyOBxqmyV\nbJvT29VoNGBzmAno4GsArSappAliTqnatZWraHY1cnLkLGOhiYz3ztRGUX/0vjRxjZHgGDvrts0y\nP0CaYC/UcagxQkh2pMczNpa4q3EPAgIH+97MuLYkUSTS24uloUFVdU+FN/rfQkLinub98059ejWa\ncK1ckWxunVkYz2WhXqeK1n4oy0k20tsjN5xJnp7VIEkSL/a8hoDAPc3aSy3frpQEO/o5UG1tbXJz\n67HMfRTnEsiiqQIcaN6PQTDwat8bGV/mGYHWrnp8b/S/TUyMcV/znRmrJzqdloJfZHONB4PDkYrY\nUUMwEMXuNGOYo2UJgsC7Wu9BQuKVLFp7uKcHU2UVJnfuUr0KB/veAODepKBIJ2WKKXiD09YPN7TA\nWqi0VbCzbiv9/kHOj1+a93lsZBgpEtZkhgnHw7ze/xYOkz3VJSkdvRpNOJOOfbV29oUE+9aajZRb\nynh78HjGaKl8lJ0rk9fpne5nT/N2ah2e3DcUCSXBjj72REhzFiUXWC5CSU3VmaHed6Wtgu2ezfT7\nB7k2Nb/VYKS7G4PdPqt59ULExDiv9b2JzWhjX2PmeucOl+wwixXgMBMEAVtbO7HhIRJBdfVeFmpi\nvbN2K5XWCo4MHMUfm53wIzevntKkrQ/4h7g4foUV5e20ls03WzidFgRBB5NUY5Pc3FqlSWohgQbw\nnrb7AHi+65V5G324S04CsmlIVDvUf4RAPMj9LXdnND/oEQYMssYO6k2UC82D0WDkQPN+wolIRlv7\njGBvVz2+F7pfBeD969+j+p5ioCTY0U+w2zQK9kAggsEgYLVlrm9yX8tdAPxmzssshsNEh4dkW6JK\nx+nx4dP4otNy+QBT5rBAvY7fyganRluNRePEogkcWZpZGA1G7mu5i6gY4/W+2ZEh+djXn+18AYD3\ntN2b8XPBIMz0Pi0AwWTC2tSsurl1LsHe5Gpgc80GOn098zZ6xWFva1Mn2COJKC/3HMJmtKXMG3PR\n6xSrJM+p3uCy2NgVDjTvx2ly8HLvoXlljSPd3QhWG+Y6dQla3b5erkxeZ23lKlZW5Vfm+HalJNjR\nJzkH0h2oXaquV7JOswnnjvI2NlSt5erkdS5PXEv9PdIrN69Wm4QRS8T4deeLmAQj97ZkfpFhZh4K\nFWqK5hjuzJ1OHgwosfzZbcPKZnSw7w3CaUfwlIamUmPv8fVxevQ87WWtbKpen/U6R4EF0RSsrW1y\nc+vB3MXhsvlb0nmw7X4Anu96ddbfw12dsuNU5Ty80f8W/liA+1ruxGG2Z7xGL43d5HTKvqcedb6n\nXPNgM9l4oPUAoXiIV5MmNQAxEiE6OICttVW14/TF7oPAzGloOVES7OinsRudTswejyoHaqZ43Uy8\nb+V7AXj6xq9TyRlhjbVRXus/zER4kgPNd6YyLDOhxNMXHPrZnhTs3SoE+5xyAhm/z2Tj/pa78ccC\nPN89I9RmTBDqErSeufk8AI+teHDBk47DaSURF4lG8jdJyePqmDXOhcgWGZROR3kraytXcXnyGlfH\n5MxkKZEg0tONpbEJgzV3Gnw0EePFnoNYjZZ5kVHpWG0mDEZBl2bO1tY2xECA+MR4zmuV9ZDJiaxw\nT/N+XGYnr/a+ntLatTpOu329nB49T6u7ibWVq1TdU0yUBDv6aewgmyHEQID4+MIO1Eg4jpiQcgr2\nFncjd9Rto9c/wMmRs/K93eodp/5YgOe6XsZhsvNQ+/0LXjtz/C4sIsRUVS1XOFQR069GoAE80HqA\nSmsFr/S+zlhoAkmSCHfexFhRgakid1OJ61OdXJy4wpqKlTm74ug1D6kNrjN3eYhcphiFhzveDcC/\nnvwhoiQSHRpEikZTz8rFq72vMx31c6D5TpxmR9brBEE2Sekh2BVnphqHeiAQxeYwY1ygcbjNZE1q\n7WFe6T2U/O6u5LPacz5DlER+dPVpJCQ+uOqRvGrX3+6UBDtgs5sRhMJty6Dezp7LlpjOYysexCgY\neebm88TFOOGebtXFnp7rfJlQPMx729+FY4EXGfQ7uQiCgLW9ndjYKAn/wr0y1ZggACxGCx9Y9TBx\nMc7Prz9LfHKShNerSlsXJZFfXP81AI+tfDDn9XqZIaxNzQgmkzqTlMr1sKqig931O7g52cOhviMz\np5b29pzPGA6O8uuul3CbXTzQeiDn9Ypg18MkBRBRc3LxR3HmWAsga+1ui4sXe15jKDCcMn+q0djf\nGjxOl6+HnbVbWbMMtXUoCXZgRjsp1LYMaY7DHNrJjKaa+/hcY6/m7qa9jIXGefbys0T7+7C1tee0\nJfb7BznUf4QaWxV3N+/P+Rw9Ty6KoMm5wanUVEGOkFlR3s7p0XN0npdjme0qBPvzXa/S6etmR+0W\nVpS357xeL1+DYDJhbWsn0tebMwM1FIgiCLKSkYsPrXoUp8XBMzefw3dD7g9rzeE4FSWRf7/0E+Ji\nnI+s/cCC2rqCw2VBTEhEwvm1jVRIKTs5NrhYNJF0pOdeC1ajhY+u+SBxMc5TF39EuKsLwWrNqewE\nY0GevvEbLEYLH1z1iOp/Q7FREuxJHAU2tVZIFYDKqbHnti2n89iKB6m113DhzKuy4zRH4a9ALMi3\nzn6XhJTgw2veh9mQu7OQXho7gK09Gb+cQ0tTa4oBeQP+8OrHEBA4d+ol+Tk5BHunt5tfd71IhbWc\nj679kJqhF9wuMR1be4ecgZqjMJrib1FjFnBbXHx8ywcJJyIMXz0jtwTMUdnyzYG3ueHtZKtnk+pa\n44rSESgwWcvocmGuqyfcdXPBDFTF9KVG2QHYVruZXXU76J/sITI4ILcEXEDZkSSJH1/7Jf5YgIfb\nH6DSVqHtH1JEFCTYn3vuOR599FHWr1/PhQvqakbfrjicFuJxkVi0MIeZ0eXCVFOT04G6UHJSJmwm\nG7+36XdpnJDHF2/OrpmIksi/XfgBY+EJHmq7n81pRa4WwmwxYjIb9NXYcwl2laYYhbayFt634iEq\nRmQTj9CcvRJfOB7m3y78AEmS+E8bPqpKS4UZwVKojR3SI4Sy29klSSLojy7oMJzL/Sv2s8rVimPE\nR6imbMGWgDemuvj59Wexm2z8b2s+oNqmrOcGZ1+xEjEUWrCEb0CDeVLhI2veR7vfgiBJRBqqFrz2\nmZvPc3ToJK3uplQo8XKlIMG+Zs0avv71r7NrV3G3kQL9Mu1A1lZFv3/BEr4zyUnqF3Gzu5GdIbmS\n43+E3s7YvT0hJvjZtV9xceIKG6rX8sgK9YkXejrMTBWVGMsrcjpQlSbWFqv6XqUPtNxDw6TERJmR\n73U9k7HD0Hhokq+c+iZj4Qne3XYvaypXqv5+XU8uyRPFQoI9GkkQj4ua1oJBMPCJqvswiXDdHeLZ\nzhczXnd54hpfP/0vxMQ4v7vutym3qs/Q1cskBaROmOGb2edB2UDU2NgVHGYHDxnWAfBi4irnxi5m\nvO5g75s83/0KHns1f7T192YVfluOFCTYV6xYQXt7e8Hmi9sBPe3L9lWyQyZ841rWawIabMsKkiTh\nGJwk6rJxnXH++7Gv8ubA28QSMSRJotPbzd8f/0de7XsDj72axzf8DgZB20+smKREsfDf1NbeTnxy\ngrh3Kus1akI+5xIfGcYUjROsr+T06Dn+7uiXOTt6QY6UiYc5N3aRvz/+VXqn+9nXsItHO7RlFerl\nPAW5hK/B4Vjw5JIyy6k0QSiYBuSNPVBXwW+6XuKpiz+k0ys3E58IT/JKzyG+ceZfEZH47OZPsq1W\nW7s3p1OfujkAthXyxhq+OT+LWkFLQEE65UNyiejBWgvfPPtdXuh+lfHQJJIk0Tc9wJMXvs9Prv0S\nt8XFn2z7zLz6QMuR5b1taSC1iHXQ0uwrZcEeun6dsn2ZE4JSha80CLX4xAQJr5eqHTt5fOPd/MeV\nn/H9yz/l+5d/ikEwpOLc72zczftXPpwzCiYTDqc11dRaq8Cdi629g8CZ04S7unBtnV+PRBQlQoEo\ndU3qtUiYMe9s2v4uhhoDHB44xjfPfReb0Uo4IQsho2Dkd9Z+iDsb92gOZzOaDHJTax0EuyAI2No7\nCF68QMLvz9h0PJDH6Q1InYYevPPjdE78hreHTvD20AncZhfTMdlUZTGY+YMtn8oZ4pkJXcOAm5oR\nzGbCndkFeyCPDU6SJELXr2EsL+f37v5j/vncv/H0jd/w9I3f4DQ5CMTlshaNznr+04aPUmPX1rug\nWMkp2B9//HHGMhS1euKJJ7j//oXjohfC43Hnfe9iUN8oCxdBmj22fMYpVmygz2Ih1n0z6/3RcByH\n00J9vfqGu2NXzwFQvXkDWzfdza6Ojfzowq+YCE4RSUSxGM18eOPDrPdof4kVqmuc3LwyitVsKvg3\nMm3byPjTP8cw1IvnATkZJv07/dMRJAkqq5yanjU9JJfCbb1jO19Ys5rf8j3ED889Q59vkFpnNR5H\nNfd27GNVdXte4/Z43JRV2Jn2hnVZp8GN6whevIB1apjKjoZ5nw/2eAGoayjT9LxY900MFgsb9uzm\nHw17OTt8iYOdR7gwcpXtDRvZ0bCZXU1bqXLk5yS0W+UInXhMLGgelHuHV6/Cd/kKVS4TRvv8jFcx\nLp8SW1orqax2qvruyOgoiakpqvbuYf2qjaxs/L95vfsoNya6uTnZTXtVM4+tfTfbGzbm3OBvN5lU\nCDkF+5NPPrkoDx4dzd2YdimJJ731I8PTqbF5PO68x2ltayd4/RpDPSMZF7HPG8JVZtP0/aOnzgOQ\nqGtO3mfmtzs+OG+chcytYJQXf3/fJEZLYUFTiepGEATGz5zH8eD0vHGODcv/bTQZNI158uIVMBoJ\nuqoJj05jxcUn1/zO7IvE/OZBGaPVZmJ0KMbgwBQm8/xKmFoQa5sAGD59gXjzfFv/0IAs2EVJUj3m\nSrtAsKcX+9p1jE/K2ZdNplY+vroV0vb1RABGA/mtB1GUEASYnAjkvabSf3NjcxtcvETf8XM41s0v\n6TAxLhd5C0diqp83ffQMAIaW9uQ9RvbX7GN/zewSvGNjC+dTFPKuLyVqNx/dwh2L3c7u1Cm0S8G2\nchUksyPnEo8liEYSmk0doZs3wGDQVL1OK3ral40OB9bmFsKdNxFj8xuGaAl1VJDicSK9PVhbWjGY\nc8d858tSOlDzsS37Ll8BScK+Kv/TWS4MSkG06cLnANLs7FnmIdVBSsNGGrpxHZgxf5aQKUiwv/TS\nSxw4cIAzZ87wuc99js985jN6jWvJ0VOgAakXLpxceOnkLdB6urE2NauqCZIvelX1U7CvXo0Ui2Us\njKY11BFk+7oUj2PX2MBbK3rOg6miAlNVNaEb1zPGcSvKRKbyzdnwXZCjP+yr1xQ8voXQqyAazETG\nhLI4UIP++R2kchG6cT2ZCLa8qjMWSkHO0wceeIAHHnhAr7HcUowmAza7WZfQLgDbSlk7CV2fHxmT\nj0CL9PUixWIprWex0H2DW72WqVdeJnTtKuzbMeuzlNPQrX4eglfkZsb2Net0GV829HQcAtjXrmX6\nyGGiA/1yL9A0gn456zS9iXUufJcugyBgX7nI68FlZXTITzQSx2or7IRkqqzCWFFB+OYNJEmaZfNO\nxEUi4Tg1deojVsRIhEhPN7aOFRjMy6OlnV6UMk/TcLosuoR2AZjcZZjr6uRFPEdLyycRQ9FycmWc\nFopTx9hlkDV2QBbsc8hHUw1dvSJ/75q1OowuO8qY9BLsjrXyRhRMjj+dgD+C3WGZ10EqG2Isiv/a\ndaytbRhsmcvu6oWe60EQBOwdK0l4vfMqPeZzig13dYIolswwGSgJ9jQcbqvcQShaWG0MBfuKVXK2\n3eDsbDslo1GTQFM01UW0qQLYHMkOQjpkXYKcqGT2eAhdn2+GCE5re5mleJzQtatYGpswlWkLkdSK\ncnIJ6DQP9qRgV35HBSXrVJNA60yao1Yv7lqAxTjBJTf6K7M3uFSoo1P9O6GYOW2LfGopRkqCPQ29\ntVVbMlEpNCdRSUvhK5CbFQcvX8JUVY25tk6XsWXDYBCwOy26vcgA9lVrEIMBgr19s/4e8EcwGAVV\nha9Arr8jRaPY1y6utg76m2LMNR5MlVWErlyZZa/OJ+s0nDTvLbZ9HcDp1i9JCcCxXi5vEbw0O0M0\nmEcsf8lxmp2SYE9D7+O3suDC1+YIdo2mmEhPD2IggGPDhiWpHe10yZUu9Yp0UgSQ7+Lslzngj+J0\nWVX/mxRtVzFrLCZ6a6qCIGBfu5aEf5rowExHpXyyToNXZbOWfdXiC/bUyUWnebA0NWN0uwlcujBr\nfWk1xUiiSOjGdUzV1arq8b/TKAn2NGZqY+ijnVgam+RFfPH87EWs0XmqaDeKtrPYOF3WlDNLD5Tj\nt+/ijBlCFCWC/ogmDW2pHKdAMuzOoFt0EMxsSKErl1J/05p1Koki4RvXsDU2YCpXn9yWL3qfXASD\nAce69SSmpoilFQTT+k5EeroR/f4leyeKjZJgTyNlitEpblcwGHBs3ETC6yXa15v6e9AvF74yW9TF\n6wYvyZUzHeuWSLAnj9+BaX02OHN9A0aXG9/FGYEWDkaRJPWaqhSPE7p+DUtD46Lb1xWcLqu+Jqk1\n8x2oWjX2aH8fYihE2frsPVv1xKljpUsFx/qNAATSzDFaywkEzstZ2M5N2urfvFMoCfY0UuVaddLY\nYWbhKQsRwO+PqDZBiLGoLNCampdEQwP9fQ2CIGBfs4bo2BjRoaFZ36021DHc3YUUiaSckEuBw2kh\nFNSnIBqAubYWU2XlLDu7Vo1dOb2VbVwawa6EYOql7EBmO7tWG3vg/DkQhNQmUWI2JcGeht4CDcCx\ncRMIQkqwJ+Ii4WAspRXnInzjBlI0uqRHTr01dgDnlq0A+M+ckr9bY6ijEua4FPZ1BYfLgiRBKKjn\nBreOxLQvFSml1d/iP30KBIHKnTtyX6wDBoMBu9Osq0nK7PHIkVKXLyEl5JLLWk6xiUCA8I3r2Fas\nxOhUV1PmnUZJsKeRqsmuo8ZucpdhbWsndP0aYjg0I9BUaqpLbV8H/TrnpOPcvFXe4M6clr97WqOm\nelk24yx2/Ho6etuXIS2e/bL8u2rZ4BJ+P6Hr17CtWImlYum6/zhdVgL+iK5lQxzrNyCGQqkG14FA\nRHUHqeClCyBJODdv0W08y42SYE/DaJS1Ez01dkiaYxIJgpcupR291WmqwUsXwGDAsQQhfgrKpqPn\nPJjKy3GvWU3o+jUSfr8mm2oiECB4+RLW1rYlM0dBWv0gHU8ujqRpzn/yhPzdGrJOA+fPgihmLIG8\nmDhcFuKxwruLzfrOpAkleOkCoigSCsRK9nUdKQn2OSyGdpJuZ1eEhBpTTCIYINzZKadML3KGYTqu\nRTDFAFTt3gWiSODc2Rmbqop58J8+CYkE7juWtlNXyiSl48nFXFWFbeUqQlcuE/f5CGrIOvWflk87\nzq3bdRuPGmYK5OnoSF6XPLlcukgoEEs+J/fpTZIkAufPYXS5sbaW6sNkoyTY5+BcBO3E1rECg8NB\n4MI5/Elh6VIj0E6evCVHTovVhNFk0NUkBVC1+w5AtrPPmCByv8z+48cAcO1cWsGu/EZ+nTc4985d\nIElMnzyhOutUiscJnj+L2ePB0pi9z+ti4FgkE6WtYwWhq1fwDcvlBdTMQ7S/j8TUFI6NmxZsXP1O\npzQzc1gM+7JgNOLYsJH42BjTQxOAOk3Vd+RNAMr27Mtxpb4IgiAnKekYCQFgb2nB7PEQPH+OgC+C\n2WLM2es0EQwQuHgBa0srlrrFzbqdy4wTWd95cN0hb3BTx0+qzjoNXrmMGA7j3Lp9SZLU0tGz92k6\n7n37QRQZPymH86oxTwbOlcwwaigJ9jnoHcuu4Eoen729Q7Oek43Y+BihK5exr1mL2ePRdSxqcLqt\nBANREon5ZWbzRRAEnFu3I4bDBLxBVRqa/9QpSCRwLbEZBtJ8DTpr7OaqamwrVjJ1U85tUGNbDiSj\niVzbltYMA/pnZCuU7doDRiMTV+X67K6yhedBkiR8bx0GoxHHpk26jmW5URLsc9C7NoaCa+cdGBxO\npsenEYTcx07fW0cAKNu7X9dxqEV5mUM6hrmBLJhEDISjkioNzX9CNsMstX0dwGQyYrObdDfFgLwe\nIkbZb5Jrk5dEEf/p0xgcjkUvApeJmdr0+s6D0e3GuXkLfp86v1P4+jWi/X24tu/E5F6aJLVipSTY\n57BYx06DxULZnXcRESzYzCzoLJMkCd+RNxFMpluiqcLiRMaAXJ0yXlkLgMO+cMxyIhggcOE81pYW\nLHX1uo5DLU63VXeNHeSNShHsuTT2wNkzxCfGcW3fiWBa+v7zi3WKBSjbt5+ISW66nsvvNHXwFQAq\n7r1P93EsN0qCfQ56t8hLp/yee4kYnViiC/dfjHR1EhsawrV9B0aHQ/dxqGExQv0ABJMJy54DAJgm\nBhe8dvrYMdkMs8RO03ScbiuxaIJoRJ+6OQrm6hrE2mYArGSfY0mSmPj1rwCofM9Duo5BLQ6XFUEA\n/3RY9+92btlGxCr38XQ4sod8xqd9+E8cx9LQuKTZx8VKSbDPYTGSUhSkihpEgxGzf4LIQH/W6xSn\nqXvfrTHDwOKE+ikIa2THl3TzEmI08zyLkQgTv3oawWymbP9duo9BLYsV+gkgJRtbx46+kfWa0LWr\nhG/ewLltO9amJt3HoAaDQcDhshLw6T8HBrOZmKMKczxE5NrlrNf53ngdKR6n/MB9S+48LkZKgn0O\n9mSjCb1NEEDKlmiNB/AefDXjNdGhQbyvH8JYXoFzw61zEC3m8Tuc/EqzfwLf4cxCbfKlF4hPTlL5\n7gcxV1XpPga1KCeXxbCzx9zVAMRPHSHS25PxmolfPwtA1Xsf0f35WnCVWQn49auboyBJEiEs2OIB\nxn72E6T4/JORJIp4XzuIYLFQtv/WKTvFREGC/R/+4R9473vfy/vf/34+//nP4/cvbGIoBpTO7Ho7\nT2FG+7WbJbxvHCLS2zvrc0kUGXryO0ixGLUf+/gtsacqLKbGrmyaNqJMPv+bVL0QhbjPx+RvnsXo\nclP50MO6P18Li1E3R8E/HUEQwBIPMfqTH837PNLbQ/D8Wexr1t7yZhIutxVRlHR3pkfCcRIJCWeZ\njUh3FxPP/XreNVMvv0hsbBT37r0YHaXaMGooSLDfddddPPvsszz99NO0tbXxzW9+U69x3VIcLquu\njSYUFOHg2b0NKRql/2tfJj41lfp88sXnCd+4jnvXbjmJ5RaSciIvgkBTNovqbZuIjY4y+sMfzGqb\nN/7M04jhMFXve/8t8zEoLKZgD/giuMpsODdsIHjh/KwKoHHvFEPffRK49do6LF6yljJpmH0fAAAa\nGklEQVSvVWtXYKyoYPyZp4mklbj2nznN6I/+A2N5OdXve7+uz17OFCTY9+/fn4ru2LZtG0PJkqzF\njtNlIREXCQVjun6vsohrNq+j5kMfJj4xQf/XvkLg/Dkmnv8N47/4GUa3m9qPfULX5+aDEuq3GCYp\nZR4aH3svloZGpl55iYGvfYXg1Sv0f+0reF99GXNdHRX33Kv7s7WSEmg6z0MiIRLwR3G5rdR8+CMg\nCAz809cY/fF/ELx0kZ6//SsiXZ2U7b8zVV/mVuJMxpj7dbazKxuFu8pJ3Sc/BYkEg//yTbyHXmP6\nxDEGv/UNBLOZpj/5Auaqal2fvZzR7az/k5/8hEceufWahR64ymwA+KZCGC36uSHSC4BVvPcRosPD\n+N58nf6vfEm+QBCo/cSnMLrduj2zEBwuK36f/pEQQX8Um92EraaKlv/z/2Hwm/9E4NxZAufOAnIr\nvdqPfeKWmqIUUmGfOgs0xTnvKrNia22j/vd+n7Gf/pjJ559j8vnnAKj+4G9R9fCjt4WzcLGcyOm1\nk1ybtlF+zwG8h15j+KknU9c0/OGfYOtYoetzlzs535zHH3+csbGxeX9/4oknuP/++wH4xje+gdls\n5rHHHlP9YI/n9hBemahvLOP8yX68UyHWbtQvfjoWkW3JbR3VWG1map74Y3qb5DR5Z1srrlUrsdXn\n97zFmM/KagcTowHKy+w5U//V4vG4CQailFfak2N2U/fXf0H3975P4GYnTR98P+Vbt9xSYZY+l5Ik\nYbYYiYRius5xKOmU9tSV4fG48Tz2IB0P3sfwiy8x+tobNH3wfVTv26t6nItNJCg7NRNxUfNzF7pe\nTMjmzqaWSjweNzX/+fP4H32IYE8Pwd4+XCtX4jlwd/4D12mcxUbOt/XJJ59c8POf//znvPbaazz1\n1FOaHjw6Oq3p+qVEMMpCxTcZ0nWck+MBzBYjvukwJGOCHe95FAAJmAam83iex+NelPlUmh50dY5T\nWV24rdvjcTPQP0UkHMdqM80as/PhD+AEYsDY2K1zwmeaS4fLwtSUvmuht2cSAKNZmPW9pt1307D7\nbkQWfkcW6zfPRizp4B4dntb03FzjHB2SP4snEjPXVTVgqGrAtW2PfM0S/DuXej7zRe3mU5Cd4dCh\nQ3z729/mG9/4BhaL+qbEtzvKsdM7FdL1ewMamzffapyL0CpwOmnaUcxdxYDLbSUcjJGI61c3J6Ch\nyuftgMNpwWAQdDfF+DWUsS6hnoLO13/zN39DLBbj05/+NABbt27lL//yL/UY1y1FKUbkndRPsMfj\nCcKhONW1Lt2+c7FZjIgQxWbvLi8ewZ6ejVxWoU9dfH9qgysOgSYnKVkWJSrGZjdhNqtr7F5CHQUJ\n9hdeeEGvcdxWKCnUemrsWhpL3C4ojkM9X+Zpb1JTLRKBBmkRIdN6CnZlHopng3O5rQwP+BBFCYNB\nHx+IPKfFMwfFQinzNAMGg4DLbcWno8ZejEdOd1Lo6Bnipphi3MUk0Bahbo5/OoLJbMBqu/WRP2px\nlVmRJHRrbB2NxIlFE0VjjiomSoI9C64yG9O+sG71yFM2VZV9HW8HFG1y2qtfyGNRmmIWySTlcltv\ni1BGtSjzoFcIbDEqO8VCSbBnwVWe1E50SkyZidctHuep1WbCYjWltGw9mPZGVNWjv53Q2yQVi8n+\nlmIywwC43PJ49drgis2BXEyUBHsWUtqqTkJN0XqLSVMFKCu3Me0N61Zewe8L43RbMRqLZ+nNJOfo\nu8kXm0Bz6Zx9qrbBRgntFM/btcS4dV7ExSrYXeVW4jGRcKjw8gpiQiQwHSk6TdWuc6hfSqAVkQMZ\n0kwxemvsRTYPxUBJsGfBlXIc6qOx+7xhLFYjVlv2ZgK3I8pGpMcG5/OGkaSZTbNYUJp769Vowl+E\nDmSYEcC6bXAlG/uiURLsWdDz2ClJEtPeMGXl+oTKLSXuVN2cwoWaEj7qKrJTC8gbXGA6qkuS0kyo\nY3EJNLtDPrnodYpN+Z2KKKCgWCgJ9iwojiI9NPZwKEY8JhadGQbSNXYdBHsyfLTYNHYgFb+uh8/F\nX6Q2doNB55PLdASL1ahbHaISM5QEexasNhNWm4lpHbSTYrWvw8yY9Qh59E4GgeJKylFwV+h3cim2\nrNN0nGU2gv4ooljYyUWSJDnkswjXQjFQEuwLUF5h10VTTQn2Isyw01ewh2Z9ZzFRVj5TyrlQ/NMR\nrDYTZkvxaaoutz5hwJFwnGgkUco6XSRKgn0ByirtRCMJIuHCOtQrWl4xCjRZABl1MUEUsynGrZhi\nCtzgZE01UnRmGAW9fE/KWtCrREOJ2ZQE+wKUJxddoTZFRRiUFaFgFwQBV5lVN429WDXVMp1MMak0\n+iLc3CDNmV7gelBOPiWNfXEoCfYFKK9MCvYCtZNitrGDvCEVenKRJAnvVKho58DhtGA0GZj2FmaK\nmYlhL855KEu+E4XWUVI2yJLGvjiUBPsCpDT2As0QPm84lZ5fjLh0iIwJh2JFrakKgoC73Fawxp4K\ndSxSU0x5pbwWCq18OqOxlwT7YlAS7AugaCeFRMYoMezFqqmCPsdvRaAVW1JOOmXlNiLheEEnF8W2\nrJwGiw1XmQ1B0FFjL+L34namJNgXQA+NPZTsvFPMtsRULHsBgl0xRxVzeJvyGxZijil2wW40GnCX\n23TR2F1lVoymkghaDEqzugDu8qR2UsDxu9jt66BPyGOqDnt5cZogANzJzOFC1oMSy1+sgh3ksYcC\nMaKR/E4uibiI3xcpaeuLSEmwL4DRaKCswo53In/tRLElLgvBXsDJxZ/snFTM86BHZIx3MoTdaS5a\nfwvM2MXznQdlHZXs64tHSbDnoKLKTjgUy7u64XLQ2O0OczIiJH9fQzE2sZ7LzMklv40+kRCZ9oYp\nr3ToOawlRzlt5NsTuBTquPgUpDZ89atf5eWXX8ZgMFBdXc0Xv/hFPB6PXmO7LSivcsCNCbyTIWx2\n7ZUZZ2LYi1c7EQQBd4Gx7FMTQSxWE3ZHcVW3TCelqeY5D9PJ6pbFbIaBdI09T8E+mXwninwebmcK\n0tg/85nP8Mtf/pJf/OIX3HvvvXz961/Xa1y3DRVV8uKbGg/mdf+Mxl68tmWQtVUlZFEroijhnQxR\nU+sqqlZwc0nVD8rTBKGY9IpesCshjwVr7MU9D7czBQl2p9OZ+u9QKITBsPwsOxVV8rF5ajI/we7z\nhrHZzUWZbZmOklKfz8s87Q0jJiSqa525L77NcZfbknXltXeUUtaQoiwUKwVr7KnkpJIpZrEoWNp8\n+ctf5umnn8btdvPUU0/pMabbivKkYM/HgSpJEn5vmOpal97DWnIqq+V5mBwPUFOn7d+jnHZqlsE8\nlFXYGBv2EwxENdcR9y2T+ihmsxGny5J3LLtvKoTZYszLtFlCHTkF++OPP87Y2Ni8vz/xxBPcf//9\nPPHEEzzxxBN861vf4nvf+x6f//znVT3Y43FrH+0toL2jGrPFiN8X0TxmnzdEIiFRU+ta9H/vYn9/\n+4oa3uQ60VBC87OuXxgBWJJ50IOFxljXUM7NK2MYMWj+twT9sgN+5eparLbCT3C3ci6ra130dE5Q\nWenAZDIueG36OCVJwucNU1XjpLa2bLGHqYliWJtqybm6nnzySVVf9Oijj/IHf/AHqgX76Oi0qutu\nJR6Pm7ExP+UVdsZH/YyM+DTZiHs7JwCwuyyL+u/1eNyLPp8Gs/zv7uuZ1Pysvp5JAKprF3+chZJr\nLk0W2dzY0zWOzaVN4xwdnsbhtOCbDkGB07AUv/lCOJwWkODm9bHUaS4Tc8cZDESJRRM4Fvmd0Mqt\nnk+1qN18CjKKd3d3p/775ZdfZsWKFYV83W1LeZWdeEzU3OtxYjQAQLWn+G3LTpcFs8XI5HhA871T\n40EEAapqijvMD/KPZU8kRPy+cNE7ThXyLQZWcpwuDQWdB7/0pS/R2dmJwWCgsbGR//bf/pte47qt\nSDlQJ0Ka4rAnxmQhWFlT/IJdEAQqaxyMDfkRRVGTo3xyIoi73JbzyF4MKGtB6wbnm1oeoY4KqVh2\njQ7UkuN0aShIsP/jP/6jXuO4rSlXQh4ngjS3V6q+b2IsgMEgLJuXubLaycjANN7J8ILH73TCoRjh\nYIy6huVhv3SX2zBbjIyPaBPsqVICRR4Ro1Be0thva5ZffOIiUJFHZIwkSUyOBamodmA0Lo9pTkXG\njKkXalMTyRA/lRvB7Y4gCFTXOpmaCBKPq4/pXy4x7AqKxq1VY1fWTrGHfN7uLA+Js8ikkpQ0xLL7\nfRFi0QRVy8AMo1BZo5gh1M+DEuqobI7LgSqPC0mCyTH186AIwOUi2K02M1abSXNew9iwH4vVVNQl\nNoqBkmBXgdVmxuYwa9LYFcfpcnAYKlRWy5uUFvuysgksF40dZpzhym+shuWmsYN8gvNNhojH1J1c\nYtEEUxMhauqKOwO5GCgJdpVUVNnxTYVIJERV1yuO06plEBGj4C63YTQZNGmqisau1iZfDCgJZ+Oj\nftX3eCdDOFyWos9ATsdT70aSYFzlBqfM13JIVLvdKQl2lVRUOpAk9WFuyykiRsFgEKiosjM1HlSd\nUj81EcRqMy2rLEPFvKbWgRoJx5n2qnc4Fws19bJDfHRIXfz32LAs2Ks1Zi6X0E5JsKskPTJGDROj\nAYwmw7Lz/lfWOInHRVWVHhMJEd9UmIpqx7I6elttJtxlVtWmGEXw1TbcXpmWheKplwW0WsE+PlLS\n2JeKkmBXiWJSGVOxiEVRYmo8SGW1A4Nh+Qg0SK8Zk3uD802FEEWJymXkOFWoqnURDEQJBaM5rx0Z\n9AFQu0xCPhUqqx2YTAZNGrvBIKSc8CUWj5JgV0ldo6xtDfX7cl477Q0Rj4vLKiJGIeVAVWFnV65Z\nTo5TBcWBqsYcMzwgrxllDS0XDAYD1XUuJsdyh36Kosj4aICqGueyCf+9nSnNsErsDgsVVXaGB3yI\n4sL25YlRWaAtJ8epwkzIo3qBphzZlxNqHaiSJDEyMI3TbcHpLu6a/Jnw1LkRRSnnBuedCJGIiyX7\n+hJREuwaqG8qJxZN5EzQmXGcLj9NtbzSjsEgpOylCzHQO4XBIFDXWL4EI1taUiGPOQRaYDpCMBBd\ndvZ1BbV29jHFvl4S7EtCSbBroK5ZMcd4F7wuFeq4DE0xRqOB2sYyxob9RMLZ+8DGonHGhvx46t2Y\nLcVfI2Yu5VV2jEYhZ6jfyKDiOF1e9nUFj8rIGCUipuQ4XRpKgl0D9U2y5jnUl93OLkkSw33eZZ1d\n19xWgSTBQM9U1muG+mWTVUPL8tPWQbYvV9Y4mRgLLGiaW672dYXKGtmBOja08AkuFepYEuxLQkmw\na6Cy2oHFalpQY58cDzLti9C6onJZhfiloxRC6+uazHrNYK88R40tFUsypltBtcdJIi4u6G9QNHZF\ns11uGAwGqmtdTIwFsjpQJUlibMSPu9ymS4ORErkpCXYNCIJAfVMZvqkwwUDmMLeeG+MAtKyoXsqh\nLSm1jWWYzAb6urNr7AO98mf1zctTYwdobJM3uO7r4xk/F0WJ0aFpKmtkhWC54ql3IYpS1rj+oD9K\nOBgr2deXkJJg10h9k3ykHs4S9thzU+6a1LqiasnGtNQYjQYaWyuYGg/iz9B8JB5PMDLgo6bOtaw1\ntPZV1QgCdF6d3zoS5HIKsWhi2TpOFXLZ2ZV3Yrmao25HSoJdI3WKnT2DOSYaiTPY68VT75Jbhy1j\nmpPaan8Gc8zIwDSJxPK1ryvY7GYaWysYGZzOuMEp9vXl6jhV8CT/fdlMc9cuDgOwan3tko3pnU5J\nsGukrtGNIGROVOrrmkQUJVqXsRlGoSkp2Pu657/Mg0kzzHK2ryt0rKkBoCuD1t6f7PW63DXVqhon\nVR4nXdfG55kofd4Q/d1T1DeXL9tggtuRkmDXiNliorrWxeigj3BodrhfygyzcvmaYRSqa53YHGb6\nuybnFQQbSDpOl7vGDtCxWhbsN6+Ozvq7fzrCjUujVFTZl71tWRAENmxrQBQlrpwbmvXZhdMDAKze\nUNLWlxJdBPt3vvMd1q1bx9RUdmfacmLNxjoSCYmTR2aaeUuSRM/NcWx207K3qYL8Mje3VRDwR2cV\nRgsGogz1e6msdmB3LG9zFICrzEZtg5uBnqlZG/25432IosTWPS3LNjoqnTUb6zCaDFw6Mzhroz9/\nsh+DQWDlOs8tHN07j4IF+9DQEIcPH6axsVGP8RQFm3Y04S6zcu5Ef6rK4cRogMB0lJaOqmVX+Csb\nTcmwx7PH+1N/e/2Fa8RjIhu3v3PWQ8eaGiRpJjomEo5z8fQADqeFNRvrbvHolgarzcyqdR68k7Lp\nBeTQ38E+Ly0dle+ITf52omDB/nd/93f82Z/9mR5jKRqMJgO77ulATEgcfb0T72SIF56+CEB78mj+\nTmD1+lqqPE4unhrg3Ik+blwe4eaVUeqby9m0s+lWD2/J6Fgja6PnT8kb/cUzA0QjCTbf0YTJtPyy\nbrOxYZu8mV86M5A0ywwCsGrDO2Nzu50oKBbtlVdeoaGhgbVr1+o1nqJhzca6/7+9u4tpMkvjAP6v\ntIDDOKaK06DD6CwOG4gFRhPdgURtbeSjVlFRboymDUZvrCB+hKJGA8aAqJekxAjRZDTK2myI0Wym\nWiEIIsYFN6Q6bHAcjAVRMhSj9OvZC9dO2NJqzOgp5fndnSYn+acfT09P3/c56Or4DY/+PYBfe19g\n7I0H6Uu/mVI/OWXRUuQVKPH3c/fQ+nMvZNFSREmnQZX31ymx/fCOfPYXSPxOjt/6hvGT+Q6ipNMg\ni46aUr9aAEAx7yvI47/Af+zP8fiXFng8Psiio/Dd95F/MUG4eW9h1+v1GBoK/Me/uLgYZrMZZ8+e\n9T/2oafqRAKJRIK/rfwLrl56ALfLixU5yf4Vy1QyY2Yscjcq8Y+f/gXXmAc/qpIi6uDqD5W3KQ29\nPQPobP0Vvw+/RvrSRMTERs6pUR9CIpFg8Y/z0fLPX/DVzFjMmhOHH5Z+G1HHAU4WEvrIavzo0SPo\n9XrExsa+7Y8yMACFQoHLly9j9mz+hmaMMVE+urD/P7VaDYvFgpkzI/8SN8YYC2d/2nXsEolkSm3F\nMMZYuPrTVuyMMcbCA995yhhjEYYLO2OMRRgu7IwxFmGEFXa73Y7CwkLk5+ejoKAADx48EBXlvc6f\nP4+cnBzodDrU1NSIjhNUuPfsqa6uRm5uLtatW4ddu3ZhdPT9B2J/Ts3NzcjJyUF2djbq6upEx5mQ\nw+HA1q1bkZeXB51Oh3PnzomOFJTP58P69euxc+dO0VGCcjqdMBqNyM3NhVarRVdXl+hIE2poaMCa\nNWug0+lQWloKl2vig378SBCDwUAtLS1ERGSz2WjLli2iooTU3t5Oer2e3G43ERG9ePFCcKKJPXv2\njAwGA6lUKhoeHhYdZ0Ktra3k9XqJiOjEiRNUU1MjONEfvF4vaTQa6u/vJ5fLRWvXrqXe3l7RsQIM\nDg5ST08PERGNjo7S6tWrwzInEVF9fT2VlpbSjh07REcJ6sCBA9TY2EhERG63m5xOp+BEgRwOB6nV\nahobGyMiot27d5PFYgk5R9iKXSKRwOl8e+KK0+mEQhGe/SQuXLiA7du3Qyp9e/fcrFnh2ZJ3MvTs\nyczMxLRpb99yGRkZcDgc75nx+XR3d2P+/PmYN28eZDIZtFotrFar6FgB5syZg5SUFABAXFwckpKS\nMDg4KDhVIIfDgVu3bmHTpk2iowQ1OjqKzs5ObNy4EQAglUrx5Zfh2WLZ5/Ph9evX8Hg8ePPmDb7+\nOnQbZGH3+paVlaGoqAhVVVUgIly8eFFUlJAeP36Mzs5OnD59GjExMdi/fz+USqXoWONMxp49jY2N\n0Gq1omP4DQwMICEhwT9WKBRhvT0IAP39/bDb7UhLSxMdJcC7hca7xVs46u/vh1wuR1lZGex2OxYt\nWoTy8nLExobXgSAKhQJ6vR4rV67E9OnTkZWVhczMzJBzPmlhD9ZnpqSkBLdv30Z5eTk0Gg2uX78O\nk8mE+vr6TxknqFD9cLxeL0ZGRnDp0iV0d3ejuLhYyEpusvTsCfWaq9VqAEBtbS1kMhl0Ot3njheU\nyOfsY7x69QpGoxEmkwlxcXGi44xjs9kQHx+PlJQU3LlzR3ScoDweD3p6enD48GEolUocO3YMdXV1\nMBqNoqONMzIyAqvVips3b2LGjBkwGo1oamoK/fn55BtEQSxZsmTcePHixYKShFZUVEQdHR3+sUaj\noZcvXwpMNN7Dhw8pMzOT1Go1qVQqSk1NJZVKRUNDQ6KjTejKlStUWFjo3y8MF/fv3yeDweAfm81m\nMpvNAhMF53a7yWAwUENDg+goEzp58iStWLGC1Go1ZWVlUUZGBu3bt090rADPnz8ntVrtH9+9ezcs\n/w+4du0alZeX+8cWi4WOHj0aco6wPXaFQoGOjg4AQFtbGxYsWCAqSkgajQZtbW0AgL6+Png8Hsjl\ncsGp/pCcnIzW1lZYrVbcuHEDCoUCFoslLBuxNTc348yZM6itrUV0dHgdvKBUKvHkyRM8ffoULpcL\nV69exapVq0THmpDJZMLChQuxbds20VEmtGfPHthsNlitVpw6dQrLli1DdXW16FgB4uPjkZCQgL6+\nPgBAe3s7kpKSBKcKNHfuXHR1dWFsbAxE9EE5he2xV1RUoLKyEj6fDzExMaioqBAVJaQNGzbAZDJB\np9NBJpOhqqpKdKSQwrlnT2VlJdxuNwwGAwAgPT0dR44cERvqf6KionDo0CEYDAYQEQoKCsLyQ37v\n3j00NTUhOTkZ+fn5kEgkKCkpwfLly0VHm5QOHjyIvXv3wuPxIDExEcePHxcdKUBaWhqys7ORn58P\nqVSK1NRUbN68OeQc7hXDGGMRhu88ZYyxCMOFnTHGIgwXdsYYizBc2BljLMJwYWeMsQjDhZ0xxiIM\nF3bGGIswXNgZYyzC/Be68EGj7hfMcwAAAABJRU5ErkJggg==\n",
"text/plain": [
- "[]"
+ "\u003cmatplotlib.figure.Figure at 0x7f385e198650\u003e"
]
},
- "execution_count": 48,
"metadata": {
"tags": []
},
- "output_type": "execute_result"
+ "output_type": "display_data"
}
],
"source": [
- "# Create TensorFlow Variables using Keras's Dense layer.\n",
+ "def f(x):\n",
+ " return tf.square(tf.sin(x))\n",
"\n",
- "wb = tf.keras.layers.Dense(units=1, use_bias=True)\n",
+ "def grad(f):\n",
+ " return lambda x: tfe.gradients_function(f)(x)[0]\n",
"\n",
- "# We can access the underlying TensorFlow variables using wb.variables.\n",
- "# However, the variables won't exist until the dimensions of the input\n",
- "# tensors are known. Once the dimensions of the input tensors are known,\n",
- "# Keras can create and initialize the variables. Until then, Keras will\n",
- "# report the variables as an empty list: [].\n",
+ "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2Ï€ and +2Ï€\n",
"\n",
- "wb.variables"
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "plt.plot(x, f(x), label=\"f\")\n",
+ "plt.plot(x, grad(f)(x), label=\"first derivative\")\n",
+ "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n",
+ "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n",
+ "plt.legend()\n",
+ "plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "docKLUaonYG_"
+ "id": "-39gouo7mtgu"
},
"source": [
- "## Step 3: *Define the loss function*\n",
+ "## Gradient tapes\n",
"\n",
- "Our loss function is the standard L2 loss (where we reduce the loss to its mean across its inputs)."
+ "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n",
+ "\n",
+ "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n",
+ "\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
@@ -245,125 +182,42 @@
}
},
"colab_type": "code",
- "id": "0_w8ZJSCtuY7"
+ "id": "MH0UfjympWf7"
},
"outputs": [],
"source": [
- "def loss_fn(predictions, labels):\n",
- " \"\"\"Calculates the mean L2 loss for our linear model.\"\"\"\n",
- " return tf.reduce_mean(tf.square(predictions - labels))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "cellView": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "base_uri": "https://localhost:8080/",
- "height": 34
- },
- "colab_type": "code",
- "executionInfo": {
- "elapsed": 348,
- "status": "ok",
- "timestamp": 1525154234538,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 420
- },
- "id": "RkNbXoXkpjVH",
- "outputId": "e4688f3c-e29f-416d-f541-6d81953b5660"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: id=1252, shape=(), dtype=float32, numpy=16.979801\u003e"
- ]
- },
- "execution_count": 50,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Test loss function (optional).\n",
+ "def f(x, y):\n",
+ " output = 1\n",
+ " for i in range(y):\n",
+ " output = tf.multiply(output, x)\n",
+ " return output\n",
"\n",
- "loss_fn(wb(inputs), labels)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "base_uri": "https://localhost:8080/",
- "height": 51
- },
- "colab_type": "code",
- "executionInfo": {
- "elapsed": 418,
- "status": "ok",
- "timestamp": 1525154260083,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 420
- },
- "id": "K_7beXoHOU7t",
- "outputId": "8f55c028-fe2b-4edb-ad68-a849afc60623"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "w: -0.311619\n",
- "b: 0.000000\n"
- ]
- }
- ],
- "source": [
- "# At this point, the variables exist, and can now be queried:\n",
+ "def g(x, y):\n",
+ " # Return the gradient of `f` with respect to it's first parameter\n",
+ " return tfe.gradients_function(f)(x, y)[0]\n",
"\n",
- "w, b = wb.variables\n",
- "print(\"w: %f\" % w.numpy())\n",
- "print(\"b: %f\" % b.numpy())"
+ "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n",
+ "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n",
+ "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n",
+ "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "JVDWpL9VYWdP"
+ "id": "aNmR5-jhpX2t"
},
"source": [
- "## Step 4: Create an optimizer\n",
+ "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n",
"\n",
- "We'll use a `GradientDescentOptimizer` to fit our model."
+ "For example:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
@@ -371,36 +225,48 @@
}
},
"colab_type": "code",
- "id": "DudNEebMKDWN"
+ "id": "bAFeIE8EuVIq"
},
"outputs": [],
"source": [
- "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)"
+ "x = tf.ones((2, 2))\n",
+ " \n",
+ "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n",
+ "# a single t.gradient() call when the bug is resolved.\n",
+ "with tf.GradientTape(persistent=True) as t:\n",
+ " # TODO(ashankar): Explain with \"watch\" argument better?\n",
+ " t.watch(x)\n",
+ " y = tf.reduce_sum(x)\n",
+ " z = tf.multiply(y, y)\n",
+ "\n",
+ "# Use the same tape to compute the derivative of z with respect to the\n",
+ "# intermediate value y.\n",
+ "dz_dy = t.gradient(z, y)\n",
+ "assert dz_dy.numpy() == 8.0\n",
+ "\n",
+ "# Derivative of z with respect to the original input tensor x\n",
+ "dz_dx = t.gradient(z, x)\n",
+ "for i in [0, 1]:\n",
+ " for j in [0, 1]:\n",
+ " assert dz_dx[i][j].numpy() == 8.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "YBeJYxY8YaiO"
+ "id": "DK05KXrAAld3"
},
"source": [
- "### Step 5: Define a training step\n",
- "\n",
- "To fit model variables to the data we'll need to:\n",
+ "### Higher-order gradients\n",
"\n",
- "1. Calculate the gradients of the loss with respect to the model variables.\n",
- "2. Use `optimizer` to compute updates to the variable values based on those gradients.\n",
- "\n",
- "To calculate the gradients, we use the [`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context manager\n",
- "and its `gradient` function to compute gradients through computation conducted within its context:\n"
+ "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
- "cellView": "code",
"colab": {
"autoexec": {
"startup": false,
@@ -408,163 +274,37 @@
}
},
"colab_type": "code",
- "id": "diDZfrMJM3OC"
+ "id": "cPQgthZ7ugRJ"
},
"outputs": [],
"source": [
- "def run_step(inputs, labels):\n",
- " with tf.GradientTape() as g:\n",
- " loss = loss_fn(wb(inputs), labels)\n",
- " # Compute the partial derivatives of loss with respect to the variables\n",
- " grads = g.gradient(loss, wb.variables)\n",
- " optimizer.apply_gradients(zip(grads, wb.variables))\n",
- " return loss"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "1WWepgmJQOzc"
- },
- "source": [
- "Repeatedly running the training step will nudge the variables towards the values that best fit the data (i.e., \"w\" will move closer to 3.0, while \"b\" will tend to 2.0):\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "base_uri": "https://localhost:8080/",
- "height": 51
- },
- "colab_type": "code",
- "executionInfo": {
- "elapsed": 380,
- "status": "ok",
- "timestamp": 1525154412590,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 420
- },
- "id": "ya5Qxz5XQlhU",
- "outputId": "8dd47155-a6c1-44c5-c279-617c803f1723"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Values of w, b BEFORE applying gradients: 2.725763, 1.894334\n",
- "Values of w, b AFTER applying gradients: 2.774932, 1.922555\n"
- ]
- }
- ],
- "source": [
- "w, b = wb.variables\n",
- "print(\"Values of w, b BEFORE applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n",
- "run_step(inputs, labels)\n",
- "print(\"Values of w, b AFTER applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n"
+ "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n",
+ "\n",
+ "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n",
+ "\n",
+ "with tf.GradientTape() as t:\n",
+ " with tf.GradientTape() as t2:\n",
+ " t2.watch(x)\n",
+ " y = x * x * x\n",
+ " # Compute the gradient inside the 't' context manager\n",
+ " # which means the gradient computation is differentiable as well.\n",
+ " dy_dx = t2.gradient(y, x)\n",
+ "d2y_dx2 = t.gradient(dy_dx, x)\n",
+ "\n",
+ "assert dy_dx.numpy() == 3.0\n",
+ "assert d2y_dx2.numpy() == 6.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "61TgeLVlKEQp"
- },
- "source": [
- "## Step 6: Create a training loop\n",
- "\n",
- "Of course, now we can simply turn all of this code into a self-standing training loop. We'll also capture our loss and approximations of `w` and `b` and plot them over time."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- },
- "base_uri": "https://localhost:8080/",
- "height": 364
- },
- "colab_type": "code",
- "executionInfo": {
- "elapsed": 580,
- "status": "ok",
- "timestamp": 1525154278709,
- "user": {
- "displayName": "",
- "photoUrl": "",
- "userId": ""
- },
- "user_tz": 420
- },
- "id": "VukGe-huNaJ4",
- "outputId": "c79c8e63-c781-451e-f74f-20815d8da49f"
+ "id": "4U1KKzUpNl58"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[0.9409681558609009, 1.3733772039413452, 1.7128530740737915, 1.9793939590454102, 2.188689708709717, 2.3530514240264893, 2.4821391105651855, 2.583533763885498, 2.6631851196289062, 2.7257626056671143]\n"
- ]
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAFKCAYAAABcq1WoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xd4U2X/BvD7ZLRpumlLS6EDgbKh\niIggU7aAgPhDRKsIUoYgiK++ioAguBARXmZBEARFUBGhiChIEQcqe+/RMlpGd9KRcX5/nDZtaFra\nkuY07f25rlw5zXmSfPMk5OY5Oec8giiKIoiIiMhhFHIXQEREVN0wfImIiByM4UtERORgDF8iIiIH\nY/gSERE5GMOXiIjIwVSOeJJbtzLs/pi+vlqkpOjt/rhkjf3sGOxnx2A/Owb7WRIQ4FnsOqcd+apU\nSrlLqBbYz47BfnYM9rNjsJ/vzWnDl4iIyFkxfImIiByM4UtERORgDF8iIiIHY/gSERE5GMOXiIjI\nwRi+REREDsbwJSIih/vxx61YtGi+3GXIhuFLRETkYA45vSQREZEtGzeux65dPwMAOnbsjOeeG45/\n/tmHFSuWwNVVA1/fGnjnndk4eHB/kdtUKueNMKesPDZ2C7p37wSNxkfuUoiInN6MGVOxdetmuz2e\nQiGgb98BmDFjdontbty4hgMH/sGKFV8AAKKjX0DXrt3x3XcbMH78q2jZshX27PkVaWmpNm/z8/O3\nW82O5nSbnTMy0jFixHN48cUX5S6FiIjuw9mzZ9G0aXOoVCqoVCo0b94S58+fRdeu3fHxxx/giy9W\noUGDhvDz87d5mzNzupGvp6cXOnTohF27duHkyRNo0qSp3CURETm1GTNm33OUWhYBAZ6lms1OEABR\nFC1/GwwGCIICvXv3Rdu27fDbb3H4739fxezZc2zeFhYWbreaHc3pRr4AEB09DgCwYsVSmSshIqLy\niohoiOPHj8FoNMJoNOLkyROIiGiI1as/g1KpwoABT6Jbt564fPmizducmdONfAGgR49eqFevHr79\ndgPefnsG/P2de/MDEVF1FBQUjFatHsKECdEwm0X07z8AQUG1EBgYhEmTxsHT0wuenp4YOvQ56PX6\nIrc5M0EsPOavIKXZ/FBW69d/jokTJ+LNN6di8uQ37P74JCnt5iO6P+xnx2A/Owb7WRIQ4FnsOqfc\n7AwAL774Ijw9vbBq1Qrk5ubKXQ4REVGpOW34enp6YtiwKNy8mYQfftgkdzlERESl5rThCwAvvTQa\nCoUCMTFL4ICt50RERHbh1OEbFhaO3r374ujRw/j7731yl0NERFQqTh2+ADB6tHTY0fLlS2SuhIiI\nqHScPnwfeaQ9mjdviR9/3Ir4+Ctyl0NERHRPTh++giAgOnoszGYzVq5cLnc5REQkk/Pnz1kGYe+8\n8xZycrLL/ViHDx9ESkqyvUorwunDFwAGDhyMgICa+PLLL5CZmSl3OUREJIM9e35FQkI8AGDmzA/g\n6qop92Nt27alQsPXKc9wdTdXV1e8+OJLmDPnfWzY8BVGjoyWuyQiIrqHYcMGY+3ajRBFEX36PIaF\nC5ehUaMmmDx5PN54420EBdWCyWTCnDnv4fr1azAajXjppTFo3boNtm+PxaZNG6FSqVG/fgQGDhyM\nH37YhD17foWvry+mT38LX3yxAZ9+Oge+vr44c+Y0UlNT8OyzL2Dbtq1IS0vFokXLIQjAzJlTkZWV\nhezsbLz66uvQ6TKxd28cLl26iNmz5+DMmZP4+ut1UCpVaNiwMSZMePW+X3uVCF8AeOGFkZg/fy5W\nrFiKF198CQpFlRjUExFVOPcZU+FqxykFoRDg3ncAdPeYrKFhw8a4ePECjEYDGjVqjOPHjyIiohGS\nk5MRFFQLAPDLLz/Bz88fb701HampqZg4cQzWrPkaX3+9DnPmzEdgYBC2bduCOnXqoG3bdujSpRua\nNGlm9TxKpQoLFizFzJlTcezYUSxYsASzZk3DwYP7ER5eF/36DUSnTl1w4MC/+PLLNXjvvY9Rv34E\nJk9+A15eXlizZiWWLfscLi4umDbtTRw9ehgtWkTeVxdVmfANCAjA4MFDsH79Ouza9TN69Ogtd0lE\nRFSCyMgHceLEMeTm5uCpp57Gnj270bLleURENLS0OX78KI4cOYSjRw8DAHJycmAwGNC9ey9MmfI6\nevXqg+7de5W4iblxY2n2Oz8/f8tMSL6+ftDpMlGjhh/WrPkM69evhcFggEZj/TiXLl1EUlIiJk8e\nDwDQ6TKRmJiIFi3u77VXmfAFgFGjxmL9+nWIiVnK8CUiKiXdjNn3HKWWRUCAJ3SlOLdzq1atsW7d\nauTkZKNfvwHYtm0rjh07ggcffMjSRqVS4/nnRxT5To+KehE9evRBXNxOvPLKWCxeXPwOt0ql0uay\nKIrYuPEr+PvXxLRps3D69EksWjTf6r5qtbSped68Rfd8PWVRpbbNNmvWHB06dMJvv+3GqVMn5S6H\niIhKEBoahqSkJGRm6qDVusPPzw9798ZZhW+TJs3w++97AAApKcmIiVkMs9mMmJjF8Pf3x9Chz6FZ\ns+ZITEyEIAgwmUxlqiEtLRW1a9cBAOzZsxtGoxEAoFAoYDKZEBoajsuXL1l2vlq5Mga3bt2879de\nqvA9e/YsunfvjnXr1gEAbty4gaioKAwbNgwTJ06sVBMbcK5fIiLn4evri6CgIABS0N64cQM1awZa\n1j/2WHe4uWkxZswIvPHGq2jRIhIKhQJarTtGj34REyeOhSAIaNAgAi1btsL8+R9j//5/Sv38vXv3\nxYYNX+LVV19G06bNcOfOHWzbtgWRkQ9i6tT/4vr1a5g48TX85z8TMXbsCKSlpcLfP+C+X/c9pxTU\n6/UYPXo0wsPD0bBhQzz33HN466230KlTJ/Tp0wfz5s1DUFAQhg0bVuxjVMTUUsVNWWUymdCu3YO4\nceM6Dh06xbl+7xOnBnMM9rNjsJ8dg/0sua8pBV1cXLBixQrUrFnTctvff/+Nbt26AQC6du2Kv/76\nyw5l2odSqcSoUWOQk5ODtWs/l7scIiKiIu4ZviqVqsjeX1lZWXBxcQEA+Pn54datWxVTXTk988xz\nnOuXiIgqrfve27k0U/n5+mqhUinv2a6sihvSBwR44qWXRuLTTz9FXNxPePbZZ+3+3NVJSZtOyH7Y\nz47BfnYM9nPJyhW+Wq0W2dnZ0Gg0SEpKstokbUtKir5cxZXkXr8pDBv2IhYsWIC5cz9Bjx79IQiC\n3WuoDvjbjWOwnx2D/ewY7GfJff3ma0v79u2xY8cOAMDPP/+Mjh07lq+yChQWFo5evR7H4cOH8M8/\nf8tdDhERkcU9w/f48eOIiorC999/jy+++AJRUVEYP348Nm/ejGHDhiE1NRUDBw50RK1lxrl+iYio\nMrrnoUb24MhDjQoTRRHdunXEyZPH8e+/RxESEmr3Oqo6bj5yDPazY7CfHcPe/RwXtwtdunSz2+M5\nit03OzsLzvVLROTcbty4jp07d8hdht1V6fAFgEGDnoK/fwDWrVvDuX6JiCqRYcMGw2QywWg0okeP\nTjh9Wjot8OTJ45GYeAMAMG/eRzh8+CA+/3wFVq6MwaxZ0zFu3EvYv/8fTJ36huWx+vaVRsaXLl3E\nK6+MwcSJY/HWW68hI6Nybumo8uGbP9dvenoaNmz4Su5yiIgqpRqtm9m8aAptNfQcN8pmG8/o4ZY2\nmrWrgfDwUj1n/pSC586dsUwpaDabraYUfOaZKERGPogXXxwFADAaDViy5LNip42dP/9jvP76FCxY\nsBRt2jyCTZs2lqs/KlqVD19AmutXOlPXUpjNZrnLISIiFEwpeOzYETz11NM4efIELlywnlLwbvnT\nAxbn5MkT+Oij2Rg/Pho7dvxomRChsqlSUwoWp2bNmnjyyf/D119/ybl+iYhsSD5w/J5tMpasuGeb\n7Kjh8Jw8AbDTlIJ3U6vVAFDk3A35sxFpNBosXBhT6c/tUC1GvoA01y8AxMRwtiMiosqgNFMK5k/t\ndzd3d3fcuXMbAHD+/Dno9dLJnOrXb4B9+/4EAOzcuaNMMxw5UrUJ3+bNW+DRRztyrl8iokrkXlMK\nhoXVxZkzp/G//31idb/69SOg0bhhzJgR2LHjRwQFBQMAJk78D9au/Rzjx0fjxx9jS9yELacqfZzv\n3bZv34YXXngGzz33AubNW2j3mqoiHhfpGOxnx2A/Owb7WVJtj/O9W8+evREWFo5vvvkat2/flrsc\nIiKqpqpV+HKuXyIiqgyqVfgC0ly/Hh6enOuXiIhkU+3C19PTC88+G4WkpERs2fK93OUQEVE1VO3C\nFwBGjhwNQRCwfPkSOGB/MyIiIivVMnzDw+uid+++OHz4EP79t3IeA0ZERFVXtQxfgHP9EhHJ6ccf\nt2LRovl2eSydLhP//LMPALB27WocP3603I+VmJiIkyfvfbav+1Vtw7ddu0fRrFkLxMb+gISEeLnL\nISKicjpz5rQlfKOihqNZsxblfqyDB//FqVMn7FVasarFuZ1tyZ/r95VXxmLVqhV4551ZcpdERFSt\n3LhxDf/5zyu4eTMJQ4YMQ79+A6zWf/fdRuzc+RMEQYGOHbvgmWeew9mzp/HJJx9BrVbDxcUFM2d+\ngHnz5kCv1yEkJBTHjx9Fly7dkJaWisOHDyI1NRWXLl1EdPRY7Ny5A5cvX8L06bPRtGkzLFw4DydP\nnkBubi4GDhyMDh06Y9Wq5VCpVAgMDELt2iH49NM5EAQBWq0WU6bMgKdn8SfOKItqG76ANNfvu+9O\nx7p1a/Daa/+Fh4eH3CURETncjBmu2LrVfnGgUAB9+7pixoycEtslJMRj1aovodNlYvjwYejb9wnL\nhAjXr19DXNwuLFmyEgAwduxIdO3aHT/+uBWDBj2F3r374sCBf5GcfAfDhkXh4sULGDDgSatNzgkJ\n8Viy5DNs3boZ69atxqpVX2L79q3YuXMH6tdvgKCgYEyYMBk5OdkYMmQg+vcfiD59+sHHxwcdOnTG\nxIlj8frrUxASEopNm77Bpk0b8cILI+3SR9U6fPPn+v344w+wceN6jBgxSu6SiIiqjRYtIqFSqeDt\n7QN3d3ekpaXBx8cHAHDq1AlcvZqACRNGAwD0eh0SE6+jQ4fOmDv3QyQkxKNbtx4ICwvHiRPHbD5+\no0ZNIAgC/Pz8Ua9eAyiVSvj6+kGnOwJXV1ekp6dhzJgRUKlUSE1NKXL//OkJAcBgMKBx4yZ2e+3V\nOnwBaa7fBQs+wYoVSzF8+MhiJ2gmIqqqZszIuecotSykczuX5vGsp/0rPAugSqVGu3aP4o033i5y\nr88++wJ//rkXs2fPwPjxk4p9dKVSaXNZFEUcOnQABw/ux6JF0mbmHj06Frl/RU5PWO2TJn+u3wsX\nzuPXX3+RuxwiomrjxImjMJlMSElJQVZWFry8vC3rGjZsjIMHDyA7OxuiKGL+/LnIycnGd99tQHp6\nGnr27IOnnx6Gs2dPQxAEm9MOliQtLRU1awZCpVLh99/3wGQyw2AwWE1hWJHTE1b7kS8gzfX79ddf\nIiZmCbp37yV3OURE1UJoaDimTXsT164lIDp6nNUIMygoCEOGPIOXXx4FhUKBTp26wNVVg9q1QzBt\n2pvw8PCAWq3GlCnvIDU1BcuWLURAQM1SP/dDD7XFl1+uwfjx0ejYsTPat++AuXM/QPfuPTF79gz4\n+Phi4sT/YM6c9/Dll2vg4uKKGTNm2+21V6spBUsyaFBf/PHHXvz2299o1Kix3R7X2XFqMMdgPzsG\n+9kx2M8STilYCtHR0kk3VqxYKnMlRERU1TF88xSe6/fOnTtyl0NERFUYwzdP/ly/2dnZnOuXiIgq\nFMO3EM71S0REjsDwLSR/rt/ExBvYunWz3OUQEVEVxfC9S/5cvzExiznXLxERVQiG71041y8RUcUr\nzZSCu3fvdFA1jsfwtYFz/RIRyW/dujVyl1BhGL42cK5fIqKKlz+l4PPPP43Y2B+s1n311Rc4f/4s\npkx5HQcP7scbb0zC+PHROH36FPr27WZpN3XqGzh4cD/0eh2mTn0DEyeOxfjx0Th//pyjX06ZMHxt\nyJ/r12w2Y9WqFXKXQ0RU4Vq3drd5WblSbWkzbpzGZpvoaI2lzdq1aoSHl+45ExLi8eGH87BwYQxW\nroyx2s9m2LDn4eHhgfff/xgAcOHCecybt6jYMxBu3Lgebdu2x4IFS/Haa29i0aJPy94JDsTwLcbA\ngYPh7x+AdevWIDMzU+5yiIiqHFtTChanfv0GcHFxKXb9sWNHsXnzdxg/PhqffPIhdLrK/b3NiRWK\nodFoMHz4SMyd+yHn+iWiKu/AAd092yxZkn3PNlFRBkyerMGtW6V51uKnFLybWq22ebvRaMxbr8Kr\nr76OZs1alOaJZceRbwleeGEkXFxcsGLFUpjNZrnLISKqUkqaUhAAzGbbh3sKgoDs7GxkZ2fj7Nkz\nAIAmTZrht9/iAACXLl3E11+vq9Da7xfDtwSBgYEYNOgpzvVLRFQB8qcUnDRpbJEpBQEgIqIhRo16\nvsj9Bg58CtHRL+D992eiYUPpN+Cnnnoa164lYNy4l/DRR7MRGfmgQ15DeXFKwXs4duwIunXriM6d\nu+Kbb3649x2qGE4N5hjsZ8dgPzsG+1nCKQXvQ/PmLdG+fQfs2bMbp0+fkrscIiKqAhi+pcC5fomI\nyJ4YvqXQq1cfhIZyrl8iIrIPhm8pSHP9jkZ2djbWrVstdzlEROTkGL6lNGxYFDw8PLFy5XIYDAa5\nyyEiIifG8C0lT08vDBv2HOf6JSKi+8bwLQPO9UtERPbA8C2DunUfQK9ej+PQoYPYv59z/RIRUfmU\nK3x1Oh3Gjx+PqKgoDB06FHv37rV3XZVWwVy/POyIiIjKp1zh+/3336Nu3bpYu3YtFixYgPfee8/e\ndVVa7dt3QNOmzREb+wOuXk2QuxwiInJC5QpfX19fpKamAgDS09Ph6+tr16IqM0EQMHr0OJhMJs71\nS0RE5VLuczuPHDkS8fHxSE9PR0xMDCIjI4ttazSaoFIpy11kZZOdnY2wsDDk5ubi6tWrcHd3l7sk\nIiJyIuWaz/eHH35AcHAwVq5cidOnT2PKlCnYtGlTse1TUvTlLrA4cp+4+/nnR2Du3A+xePFyvPji\nS7LVUdHk7ufqgv3sGOxnx2A/S+w+scLBgwfRoUMHAECjRo1w8+ZNmEym8lXnpDjXLxERlVe5wjcs\nLAxHjhwBAFy7dg3u7u5QKqvOZuXSyJ/r9/z5c9i9e6fc5RARkRMpV/g+/fTTuHbtGp577jm89tpr\nmDFjhp3Lcg7R0WMBADExS2SuhIiInEm5fvN1d3fHggUL7F2L08mf6zcu7lecPn0KjRo1lrskIiJy\nAjzD1X0qmOt3mcyVEBGRs2D43qeCuX7XIzmZc/0SEdG9MXzvU+G5fteuXS13OURE5AQYvnbAuX6J\niKgsGL52wLl+iYioLBi+dsK5fomIqLQYvnbCuX6JiKi0GL52xLl+iYioNBi+dsS5fomIqDQYvnbE\nuX6JiKg0GL52NnDgYPj7B2Dt2tXQ6XRyl0NERJUQw9fONBoNhg8fibS0VGzcuF7ucoiIqBJi+FYA\nzvVLREQlYfhWgMDAQAwcOBjnz59DXNwuucshIqJKhuFbQTjXLxERFYfhW0FatIhEu3aPYvfuXThz\n5rTc5RARUSXC8K1AnOuXiIhsYfhWoN69H0doaBjn+iUiIisM3wqkVCrx0kujkZWVhXXr1shdDhER\nVRIM3wo2bFgU3N09ONcvERFZMHwrmJeXN4YNew43blxHbOwPcpdDRESVAMPXAV56aQwEQcAHH8xC\nZmam3OUQEZHMGL4OULfuA3j55Ym4fPkSpk79r9zlEBGRzBi+DvLmm1PRokUkvvpqLbZu3Sx3OURE\nJCOGr4O4uLhg2bKVcHNzw+TJr+Datatyl0RERDJh+DpQ/foNMGvWh0hLS8X48aNhMpnkLomIiGTA\n8HWwqKjh6NOnH/74Yy8WL/6f3OUQEZEMGL4OJggC5s1biMDAIHz44SwcPnxQ7pKIiMjBGL4y8PPz\nw6JFMTAajRgzZiR0Op3cJRERkQMxfGXSuXNXjBv3Ci5evIBp096UuxwiInIghq+M3nprGpo1a4F1\n69Zg61ae/YqIqLpg+MrI1dXVcvjRa69NwPXr1+QuiYiIHIDhK7OIiIaYOfN9pKamYsKEMTCbzXKX\nREREFYzhWwm88MII9O79OPbu3YMlSxbKXQ4REVUwhm8lIB1+tAg1awbigw/exdGjh+UuiYiIKhDD\nt5Lw9/fHwoXLYDAYePgREVEVx/CtRLp27YbRo1/G+fPnMH36FLnLISKiCsLwrWSmTp2Bpk2bY+3a\nz/Hjj7Fyl0NERBWA4VvJ5B9+pNFoMHnyeCQm3pC7JCIisjOGbyXUsGEjzJjxHpKTk/Hyy6N5+BER\nURXD8K2kXnzxJfTs2Rt798Zh2bLFcpdDRER2xPCtpARBwKefLkZAQE28994MHDt2RO6SiIjIThi+\nlVhAQAAWLlxqOfxIr9fLXRIREdkBw7eSe+yxHoiOHotz587inXfelrscIiKyA4avE5g6dSYaN26K\nNWtW4qeffpS7HCIiuk/lDt8tW7bgiSeewJNPPom4uDg7lkR302g0WLZsJVxdXfHqqy8jKSlR7pKI\niOg+lCt8U1JSsHjxYnz11VdYtmwZdu3aZe+66C6NGzfBjBmzcefOHc5+RETk5MoVvn/99RfatWsH\nDw8P1KxZE7NmzbJ3XWTDiBHR6N69J+LifsXy5UvkLoeIiMqpXOF79epVZGdnY8yYMRg2bBj++usv\ne9dFNgiCgAULlsLfPwCzZ8/AsWNH5S6JiIjKQRBFUSzrnZYvX46DBw9i0aJFuH79Op5//nns3r0b\ngiDYbG80mqBSKe+7WJJs374djz/+OBo3boz9+/dDq9XKXRIREZWBqjx38vPzQ6tWraBSqRAaGgp3\nd3ckJyfDz8/PZvuUFPsfnxoQ4IlbtzLs/rjO4KGHOuCll0bjs89iMH78RHz00bwKe67q3M+OxH52\nDPazY7CfJQEBnsWuK9dm5w4dOmDfvn0wm81ISUmBXq+Hr69vuQuksps+fRYaN26Czz//DD//vF3u\ncoiIqAzKFb6BgYHo1asXhgwZglGjRmHq1KlQKHjIsCNpNBosXSodfjRx4jgkJSXJXRIREZVSuX7z\nLauK2PzAzRqSFSuW4u23/4uuXbth/frv7P6fIPazY7CfHYP97BjsZ4ndNztT5fHSS2Pw2GPdsXv3\nLnz22TK5yyEiolJg+Dq5gsOP/PHuu9Nx4sRxuUsiIqJ7YPhWAYGBgZg/fzFyc3MxduxIZGVlyV0S\nERGVgOFbRfTs2QcjRozC6dOn8O670+Quh4iISsDwrULeeWc2GjZshJUrl2Pnzh1yl0NERMVg+FYh\nbm5uWLZsFVxcXPDKK+Nw8+ZNuUsiIiIbGL5VTNOmzTBt2kzcvn0LEyeOhQOOJCMiojJi+FZBo0aN\nRZcuj2HXrl+wcmWM3OUQEdFdGL5VkEKhwMKFy+Dn54eZM6fh1KmTcpdERESFMHyrqMDAIHz66WLk\n5ORgzJgRyM7OlrskIiLKw/Ctwnr3fhzDh4/EqVMnMWvWdLnLISKiPAzfKm7GjPcQEdEQK1Ysw65d\nP8tdDhERgeFb5Wm1WixdutJy+NGtW7fkLomIqNpj+FYDzZu3wNtvz8CtWzcxadI4Hn5ERCQzhm81\nMXr0OHTu3BW//LIDq1atkLscIqJqjeFbTeQfflSjRg3MnDkVp0+fkrskIqJqi+FbjQQF1cKnny5G\ndnY2xowZycOPiIhkwvCtZvr06Yvnnx+BkyeP4733ZspdDhFRtcTwrYZmznwP9es3QEzMYuzevUvu\ncoiIqh2GbzXk7u6OZctWQq1WY8KEMbh9+7bcJRERVSsM32qqRYtIvPXWdNy8mYRXX32Zhx8RETkQ\nw7caGzduAjp27IIdO7Zj9eqVcpdDRFRtMHyrMYVCgUWLlsHX1xfvvDMFZ8+ekbskIqJqgeFbzdWq\nFYx58xYhOzsbo0ePQE5OjtwlERFVeQxfQt++/REVNRwnThzD+++/K3c5RERVHsOXAADvvvsB6tWr\nj6VLFyIu7le5yyEiqtIYvgSg4PAjlUqFCRPG4M6dO3KXRERUZTF8yaJly1Z4881pSEpKxKuvjufh\nR0REFYThS1bGj5+IDh064aeftuGLLz6XuxwioiqJ4UtWpMOPYuDj44Pp09/C6dOn5S6JiKjKYfhS\nEcHBtfHJJwuRlZWF/v374+LFC3KXRERUpTB8yab+/Qdg8uTXcf78eTz+eDf8/fc+uUsiIqoyGL5U\nrDffnIbly5cjLS0Ngwf3w/fffyt3SUREVQLDl0o0atQorF//HVxdNRg9egQ+/fRj7gVNRHSfGL50\nT126PIbY2J9Rp04IPvhgFiZNehm5ublyl0VE5LQYvlQqjRs3wfbtuxAZ2Qrr16/DM88MRlpaqtxl\nERE5JYYvlVpgYBC+//5H9O7dF3v37kHfvj1w5cplucsiInI6DF8qE3d3d3z++TqMGTMeZ8+eQZ8+\nj+HAgX/lLouIyKkwfKnMlEol3n33fXz44SdITk7GoEF9sXXrD3KXRUTkNBi+VG4jRozCunUboFSq\nMHJkFBYtWsA9oYmISoHhS/ele/de2LLlJ9SqFYx3352G//xnEgwGg9xlERFVagxfum/Nm7fATz/9\nimbNWmDt2s/x7LP/h/T0NLnLIiKqtBi+ZBe1agVjy5af0KNHL8TF/Yr+/Xvh6tUEucsiIqqUGL5k\nNx4eHlizZj1GjozGqVMn0bv3Yzhy5JDcZRERVToMX7IrlUqFDz6Yi9mzP8StWzcxYEAfbN++Te6y\niIgqFYYvVYjo6HFYvforAMDw4cMQE7OYe0ITEeW5r/DNzs5G9+7dsWnTJnvVQ1VInz598cMP2xEQ\nUBPTpr2FKVNeh9FolLssIiLZ3Vf4Ll26FN7e3vaqhaqgli1b4aeffkXjxk2xcuVyvPDCM8jMzJS7\nLCIiWZU7fC9cuIDz58+jS5cudiyHqqI6dUIQG7sDXbo8hl9+2YEnnuiNGzeuy10WEZFsBLGcP8RF\nR0dj2rRp2Lx5M2rXro0nn3xblRT0AAAgAElEQVSy2LZGowkqlbLcRVLVYDAYMH78eCxfvhy1a9dG\nbGwsIiMj5S6LiMjhVOW50+bNmxEZGYmQkJBStU9J0ZfnaUoUEOCJW7cy7P64ZM3e/Txr1seoVSsU\nM2dOxaOPdsBnn61G9+697Pb4zoqfZ8dgPzsG+1kSEOBZ7LpyhW9cXBwSEhIQFxeHxMREuLi4ICgo\nCO3bty93kVQ9CIKAl19+BaGhYXj55VF47rmn8f77H2PEiFFyl0ZE5DDlCt/58+dblhcuXIjatWsz\neKlM+vcfgODgYERFDcWbb76GS5cuYsaM2VAq+fMEEVV9PM6XZNO6dRts374LERENEROzGC+++Bx0\nOp3cZRERVbj7Dt8JEyaUuLMVUUnCwsKxbdsv6NixM376aRsGDnwcSUmJcpdFRFShOPIl2Xl7+2D9\n+u/wzDPP4ciRQ+jTpxtOnTopd1lERBWG4UuVgouLC+bPX4wpU6bj6tUE9OvXE7t375K7LCKiCsHw\npUpDEARMmvQfxMSsQm5uDoYNewpr166WuywiIrtj+FKlM2jQU/j2263w9vbGa6+9glmz3oHZbJa7\nLCIiu2H4UqXUtu0j+PHHXahXrz4WLvwUo0YNR1ZWltxlERHZBcOXKq0HHqiHH3/ciXbtHsXWrZvx\n5JN9cevWLbnLIiK6bwxfqtR8fWtg48bNeOqpp3HgwH706dMNZ8+ekbssIqL7wvClSs/V1RWLFy/H\n66+/hfj4y+jbtwd+//03ucsiIio3hi85BUEQ8Prrb2HRohjo9ToMGTIQX3/9pdxlERGVC8OXnMqQ\nIc/gm29+gIeHB155ZSw+/HAWyjkrJhGRbBi+5HTat++AH3/chbCwcMyb9zHGjh2J7OxsucsiIio1\nhi85pfr1G2D79l/Rpk1bbNr0Lf7v/wbgzp07cpdFRFQqDF9yWv7+/vjuu60YOPBJ/P33X3j88W64\nePG83GUREd0Tw5ecmkajwbJlqzBp0n9w6dJF9OnTjXtCE1Glx/Alp6dQKDBlynTMn78YGRkZePLJ\nfnjhhWE4ffqU3KUREdnE8KUqY9iwKGzZ8hPatGmL7dtj0aVLO0yYMAbx8VfkLo2IyArDl6qUhx56\nGLGxP2Pdug1o1KgJNmz4Cu3aPYgpU17HzZs35S6PiAgAw5eqIEEQ0LNnH/z66+9YuvQzBAfXxmef\nxeDhh1vigw/eRVpaqtwlElE1x/ClKkuhUGDw4CH4888DmDPnU3h6euLTT+eiTZsWWLhwPvR6vdwl\nElE1xfClKk+tVmP48JH4++/DmDp1JgBg1qzpaNs2EqtXr4TBYJC5QiKqbgTRAefmu3Urw+6PGdCm\nOUzmoqXrx72C7JHRAADPcaOg/vuvIm0MrR9CxvLVAADN2tXQzp9r8zmS/zoIuLhAee4svIc+abNN\nxryFMHTuCgDw6dUFitu3i7TJHvIM9P99GwDg/s7bcI39oUgbU2gY0r7fBgBw2b4NHlP/a/P5Urfu\ngDm4NoTUFPh262izjW7KdOQMHgIA8Hr2/6CysddvbtfuyJw7HwDgtnA+3FZ/VqSNqNVCdfoUbt3K\ngGr/P/AaPcLm86WvWgtjy1YAAN+2kRCMxiJtsqLHImv0ywAAj0kvw2XvniJtjM1bIn21dL5m16+/\nhPvHH9h8vuQ9+wAPDyguX4LP4P4222TOmYfcbj0BAD79ekJx47plndlsRkZ6OlZl6fG60Yjw8Lr4\nrmEjtDxxHBAEq8cx1wpGauzPAACXXT/D443JNp8v9butMIfXBTIzUaPzIzbb6F5/CzlDnwUAeA1/\nFqpjRyzrlAoBJrOI3I6dkTl/MQDALWYx3JYvLfI4okqFlL8PAwBURw7Ba0SUzedLj1kF40MPAwB8\nOz4MwcZIP2v4S8iaMAkA4PGfSXDZvbNIG2Ojxkj/8hsAgOt3G+H+/rs2ny9l116IPr5QXL8Gn/69\nbLbJnP0Rcvv0BQB4D+oLpY2d4XL6DYBu5nsAAO1H70GzcX2RNmZ/f6TuiAMAqPfshufkCTafL+3r\nTTA1iAByc1Gj3YOWfi5MP+k/yI4aDgDwjB4O9YH9RR7H0LYdMpasAABoVi6Hdsn/bD5f8oHjAADl\nyRPwjnraZpuMRTEwtHsUAODb9VEI6WlF2mQ/+zz0k98AALhPeR2uO7YXaWOqVx9pGzcDAFy2bobH\njKk2ny9l+68Qa9aEcPMmfPs8ZrNN5ozZyO0/EADgPWQglBeKHi+f06sPdO9/DADQzpsDzZdfFGkj\nenkjZfcfCAjwROqWn+A5frTN50tbuwGmJk0BADVaN7PZRs7vcnsJCPAsdp3Krs9E5AQUCgW8fXzw\nwpBncBoivvjic+y4fAmBajW8vX3g5uYmd4lEVMU578g3wLNCHpesVYd+vnLlMj7++AN8883XEEUR\nDz/8CKZOnYFHHmnvsBqqQz9XBuxnx2A/S0oa+fI3X6r2wsLCsWhRDPbs2Yc+ffrhn3/24YknemPo\n0CdxrNCmYSIie2H4EuVp1Kgx1qz5Ctu370KHDp3w66870a1bR0RHD+c5o4nIrhi+RHdp3boNvvtu\nKzZu3IzIyFbYvHkTHn20DV577RVcv35N7vKIqApg+BLZIAgCunR5DDt2xGHlyrV44IF6WLt2Ndq2\njcQ777yN5GROX0hE5cfwJSqBIAjo338A9uzZhwULliAgoCaWLl2Ihx5qgblzP0RmJncqIXJqZjOg\n0wGZmQ59Wu7tTCViP1vLycnBmjUrMX/+XNy+fRv+/v6YOPE1vPDCSGg0mnI/LvvZMdjPjnHf/SyK\ngMEAITsLQlYWoNdDyM6GkKWHkJUFITsL0GdJfxe6HdlZEPRZBW2yCrXRF2pT+PbsbOkpFQqkffUt\nDI91t1MvlLy3M8OXSsR+ti0zMwMxMUuwZMlCZGSko3btOnj99bcwZMgzUKnKfvg8+9kx2M92IoqA\nTgdBp4Ogy4Sg00Ghy4SgywR0OngrTMhISraEoJCVBdwVggXhWNBG0OuB/DA1mexbsloN0U0L0c0N\n0GggarUQNRrLbaKPL3RvvwNznRC7PSfDl8qN/Vyy5OQ7+N//PsWqVcuRnZ2NBg0i8Oab09Cv3xMQ\n7jpbVknYz45RLftZFKWQKxSUQmZmwbLuruXM/GUdhMyMguXC6/Q6CHaIDlEQADc3KfzcCsIQbm4Q\nNW4QtXnrNG557Qq3KRScGqkd7g5UjRugzbsux3+K7xfDl8qN/Vw6169fwyeffISvvloLk8mEyMhW\nmDLlHXTu3LVUIcx+dgyn6Oe8sFSkp0FIS4OQnlZ8GBYXmlZ/Z0Iwm8tfjiBAdPeA6O4uXTw8Cy17\nWK9zl9Z5BvkhzaQoCNG84LQEZn6AuroWOaVrVcLwpXJjP5fNhQvn8NFH72Hz5k0AgA4dOuHtt99B\n69ZtSrwf+9kxHNLPZjOEjPS84Ey3CtGC5XTp70LLVuttnB+9tEStuyUMzR6eQKHl/Nvh7gGzR35o\neuSFqDtErUeRUIVWW+aA5OdZwvClcmM/l8+xY0fw/vvvYteuXwAAvXv3xVtvTUPjxk1stmc/O0ap\n+jk3F0JaGhQZxQVlqnVopqVBUWhZyMwo8yZZUauF2csbopcXRC9vmL3zl32k68JBefeIs/Cy1h1Q\nyH8QCz/PEoYvlRv7+f7s2/cnZs+egX/+2QdBEPDUU0/jjTemICws3Kod+9nOjEYIKSlQJN+BIvkO\nhDvStWeuDvrrN4sP0Yx0aeegMhAFAaKnF0Rvbyk4ve5a9vKG6C2FqNnLO29d/rJ0O9TqCuoIefDz\nLGH4Urmxn++fKIrYuXMH3nvvXZw8eRxqtRpRUcPx6qtvIDAwEAD7uURmM4TUFCiSky0hWjhQFcl3\nICTfgeJO3nVyMhRpqaV+eFGthujtDbMlQPOCMi9ArZel8LSMUr29IXp4VorRZmXCz7OE4Uvlxn62\nH7PZjM2bv8OHH87G5cuXoNVqER09Di+//Arq1w+pHv0sitIo0yo4k6G4Yx2iluWUZAgpKaXaYUhU\nqWCu4QfRz0+6riFdm/1qWJa9wmsjRVRLI868wIWbW5Xe6UcO/N6QMHyp3NjP9mcwGPDVV2sxd+6H\nSEpKhI+PDyZNmoQePfqhQYMIucsrPVGU9qS1MRq1CtSUZGn9HWm5NDsTiQoFxBo1pPDMD1K/QoFa\no4YUsr41LIErenrdM0T5eXYM9rOE4Uvlxn6uOHq9HitXLsfChfOQmiptJm3YsBH69n0C/foNQNOm\nzcp0rLDdmM0Qbt+GMukGFEmJUCQlQZFYaDnphnR96yaE3NzSPaSvr1WIWo9M85Z9a0D0ywtTb58K\n2ZTLz7NjsJ8lDF8qN/ZzxcvISMcff/yKr77agLi4XcjOO91deHhd9Os3AP36PYFWrVrffxAbjVDc\nvpUXoolQJCbeFah5t926WeLZhUQXF5gDg2AOCIDZz79oiNbwsx61+vjIcoIDW/h5dgz2s4ThS+XG\nfnaM/H7OzMzErl0/IzZ2C375ZQf0eh0AIDi4Nvr27Y9+/Qbg4YcfgVKpLLizwQDFrZt5o9OkvBC9\nAcXNJOuQvX2rxN9ORY0G5ppBMAcGwhxUC6bAQClk8y9BtWAODIToW8NpfyPl59kx2M8Shi+VG/vZ\nMWz1c1ZqKvbH/oCD27bg8l9/wFuvRy0AdTUaNK/hh1C1Gp46HRR3bpd4XKmo1cJcMxCmoFp5IRpk\nFbJSuAZKm3qdNFRLi59nx2A/S0oK38qxLYiousnKgjIhHsqEK1DExwOpt+B58Yr1ZuDkZIQCePLu\n+2ZnA9evIQPARYUCBv8AuNVvAP9mzSEE15HC1TJaDZIOhanioUrkbBi+RBUhOxvKawlQXLmSF7Lx\nUMRflpbj46G4dbPIXfInJDR7ecMcGAhj0+Yw1wwsGK3mBaohoCb+jr+CH3b9jG3btuLGjevArZvw\nOHYUPXr0RL/QAXisVWu4u7s79jUTUalxszOViP1cjNxcKK4m5IXpFSjyrqWQvQJlUqLNu4lqNcy1\n68AUGg5TaCjMIaEwhYbBq2kE7rh6wRwYJJ1Lt5TMZjMOHtyP2NgtiI3dgvj4ywAANzc3dO3aHX37\n9kevXn3g5eVtj1ft9Ph5dgz2s6RCfvOdM2cODhw4AKPRiNGjR6Nnz57FtmX4Oq9q288GAxTXrlqP\nWuPzlhPiobhx3ebvrKJSCXPtEJhCpVA1h4TCFBIKU2g4zKGhUrgW3lkqjz36WRRFHD9+DNu2/YDY\n2C04e/YMAECtVqNTpy7o128AevfuCz8/v/t6HmdWbT/PDsZ+ltg9fPft24eVK1dixYoVSElJwaBB\ngxAXF1dse4av86qy/Ww0QnHjuvWoNX85IR6K69ds7hksKhTSyDUktFCwhsEcGibdViu4XIfVVEQ/\nnz17BrGxUhAfP34UAKBUKtG+fQf07fsEHn+8H4KCatn1OSu7Kvt5rmTYzxK7h6/JZEJOTg60Wi1M\nJhPat2+PP//80/rwh0IYvs7LafvZZIIi8YYUpFcuW0asls3E167aPJZVFASYawVbNgebQkKlYM1f\nDq5dISfBr+h+vnz5ErZt24rY2B9w4MC/AABBEPDQQw+jX78B6Nu3P0JDwyrs+SsLp/08Oxn2s6RC\nDzXasGED9u/fj48//rjYNhXxJrRp4wmzjZHJuHG5GDnSkLeswd9/F/0PQevWJixfLp3IYO1aNebP\nd7H5HH/9pYOLC3DunAJDh7rZbDNvXjY6d5a+xHv10uL27aJ7lQ4ZYsB//yudCeidd1wRG1t0ZBQa\nasb330uzqWzfrsLUqa42n2/rVj2Cg0WkpgLdutneoWbKlBwMHiydwu/ZZ91w+nTRMwV17WrE3Lk5\nAICFC12wenXRQNFqRZw+rcStWxnYv1+B0aNt98GqVVlo2VJ6L9q2dYetswdGR+di9GjpfZk0yRV7\n9xbtg+bNTVi9Wnpfvv5ahY8/tt0He/bo4OEBXL4sYPBADWA0QDAYAEPetdGIJRiLx02xAIAO2Iur\nqFPwAAoloFLh/8L3YUbfP2EOCcP033rhu31hgEpptWdwrVpmxMZK78uuXUq88YYGtnz3nR7h4SIy\nM4HOnW2/L6+/noOhQ6XOGT5cg2PHCj6bCoUCZrMZHTsaMX++9L7ExKixfHnRz6ZKBfz9t3T875Ej\nCowYYft9iYnJwkMPSe9Lx45a6PXS6zIaTcjK0iMrS4/c3PkQxTkAAD+/b2A0doebmxvUhf6D0aiR\nGV9+mZX3OlV4/33b78uuXTr4+ADXrwvo39/279azZ+egTx+pDwYNckN8fNHPZr9+RsycKfXBRx+5\nYOPGop9Nf38RO3boAQB79igxebLt9+Xrr7PQoIEZublAu3buln4ubNKkXERFSZ/N6GgNDhwo+p3R\ntq0JS5ZIn82VK9VYssT2d8aBA9L7cvKkAlFRtt+XRYuy0a6d9J3RtasW6elFvzOefdaAyZOl74wp\nU1yxY0fRfy/16pmxcaP0vmzdqsKMGbbfl+3b9ahZU8TNmwL69LH9vsyYkYP+/aX3ZcgQN1y4UPR9\n6dXLiPffl96XefNc8OWXRd8XLy8Ru3frERDgiS1b9Bg/3vb7snZtFpo0kd6H1q1t/3uR87vcXirs\nUKOdO3fi22+/xapVq0ps5+urhUple1R8PxQ2Tj/n6alBQID0hms0ts9Q5+qqQECAOq998WexCwjw\nhIsLcOdO8W18fLQICJCWVSrb7dzdXREQIP3D0Gptt1GrFZY3ytu7+Ofz8/NAQEDxzwUAXl5ulppc\nXGy3c3NzQUCA9EH18LDdJn9DRkCAJ3x9i38+X193y/MplYCt8zh4eNzP+yICJhOQKwVswKyp8Dh/\nGBnH9VCkfFv0gRQKCA3qA62GAuHhwDf1gMy8sywpVZZwVT05CO4fDJJqugkoDhV9qLK+L25uxbfx\n9Cx4X1xdi7ZTKBTQaEr3vuTXVJb3Jb+di4sCLi7e8Pb2xvPPT0OdOvWwadMm/PxzMkQxFWlpqVCr\n1dBqtXB3d4eLi9ryfF5exT+fv7/0OcnJKb6Nt3dBH6jVtttptQV94F7M9LQqVUEf+JRwJsoaNaQ+\nyM0taHP390bh7wxb7wsAaDSl/86Qntd+3xnFfaZcXBSlfF+kz6bZbL/vjNK9L9p7vi9ASf9e5Psu\nd4Ryj3z37t2LBQsW4LPPPoOPj0+JbbnZ2Xk5tJ+NRigvX4Ly7Bkoz52BKu9aee4cFLpMq6aiQgFz\naBiMEQ1himgkXTeIgKlBBEQn3LO3MnyeU1NTsGPHdmzbtgW7d+9CTo40yqlb9wHLaS4jIx+U53zT\ndlIZ+rk6qMh+FkXAaATyNnbBYBBgMEj/wTIagdxcoci6/PWF/zYYhEL3kf7TMXSoAV5e9qvV7pud\nMzIyMGzYMKxevbpUe04yfJ1XhfRzVhaU589Bde4MlGfPQHXurBSyFy8UOVG/6OICU70GBeEa0RDG\nBg1hqldf+u9wFVHZPs+ZmRnYuVM6zeXOnT9bTnNZu3Ydy2ku27RpW+x+HpVVZevnyshkks7jkpsr\nBVlODpCTI11b3ybdLt0GZGdLyzk5AtRqV6Sm5tgIwMLhd3cAFl1XEKrW7SrKypVZls3v9mD38N2w\nYQMWLlyIunXrWm776KOPEBwcbLM9w9d53U8/C2mpBeFaaDSriL9S5DAds4cnTBERMDXIC9eIhtLf\noeGV5qT8Fakyf56zsrIQF/crYmN/wI4d25GengYA8PX1RYsWkWjZshVatoxEixaRCA0Nq9Qj48rc\nz2YzkJUlhVh2trSclSXkBaFgFXZSEBaEoK1l27cVF6QFj2syyfP+qdUi1GppBKpSiXnXgIuLdHvB\npaBd/nL+xcVFzLtP0XVqtfVjFrST1nl5AW3amGwdCVhuPLczlds9+1kUobiZJIXr2TPSaDY/bG2c\naMLsH5A3ipXC1digIUwNG8EcVKtanwLRWT7Pubm5+P33PYiN3YK9e/fgypXLVut9fHzQvHkkWraM\ntARyeHjdShPIZelnUcwf0UmBqNcXBGN2tmAVjkWvC9rq9db3KQhX679zciq+jwRBhKurFDouLiI0\nGmnZ1VXMuw3QaETL+vy2rq62lotfX7OmFnq9/q4ALByu1gGoUlXNf/4MXyo3Sz+bzVDEX8nbVHy2\n0G+yZ6FISy1yP1NIKEwNCsLVmBe2om8NGV5F5eesn+fU1BQcPXoER44cxtGj0uXSpYtWbby9fdCi\nRUs0b97SEsrh4Q/Y3GGyLERRGqnpdAJ0uruvbd8mii5ISTFYQq9wgNoKUbPZ/ong5iaFnkYjXbu5\niXBzk/7Ovy58e36wFQ486/C71/qCx3BUyDnr59neGL5UJkJKMlRHDkN15BA8Lp6F4dgJqM6fhZA3\nz2w+UaWCqe4DeTs8ReSNZhvCWK+BtDsklVpV+jynpaXi2LGjeYF8CEePHsGFCxcAuAPwAOABrTYI\nDzzQEqGhTREcHIGaNetBqw1EVpaiSGjq9cWHqz02kapURcPPVgjmL+cHp5tb6f8uuL1qjvDuVpU+\nz/eDsxpRsYSMdKiOHoHq8CGojhyE+tBBKO/alKjSamGMaGS9w1NEQ5jqPlAhJ5ygysFgANLTBaSn\nAxkZAtLSBMvf6ekCMjKKjjCloNRCpwuGTtfHchtgnTh6PXD8uHQpLa1WhLu7CHd3oEYNs2XZ+rrk\n22rXdkdWViY0GunxNJpqsUsBVUL82FUnej1Ux49BfeQgVIcOQnXkEJTnz1nt/GSuUQO5XbvB0OpB\nGFs+CO9Oj+C2WwkHk1KlZDYDmZn54WkdmmlpUnCmp8OynB+sGRkFt+WflKOslEoRHh5S2Pn6iqhT\n5+5QlJbV6hxkZNzAnTuXkZh4Htevn8aNG+cgiukAMgFkws1NRNOmddGqVSO0bNkSLVu2Qv36Dcq9\nl3VAAHDrVoVv7CO6J4ZvVZWbC9XJ49KI9vBBqA8fgvLMKatTKpo9vWB4tCOMkQ/CENkKxsgHYQ4J\ntd4uFuAJcPORQ4mi9Ptj4dC0DkncFZgC0tJQaFkKUVEsW3hKe3yK8PQEgoLM8PIS8y4otFxwm6en\nCA8PEVqtdai6uJRl02qtvEs7AIBOp8Px48dw9Oghy+/IBw/GYf/+Xy330Gq1aNq0uWWHrpYtW6FB\ngwioOIQlJ8LffKsCoxHKM6ehPnIob0R7EKqTJ6yOmRXd3GBs3jJvRCsFremBevcc0bKf7092NpCc\nLFhd7tyRrlNSCv7OzFQhOdlsGZ0aDGULTkGQQlMKTxHe3sWHZuG/vb0L7uPmVjl/j9Tr9Thx4hiO\nHj2MI0eky9mzp2Eq9B9JNzc3NGnSLG+HrlZo0SISDRs2KhLI/Dw7BvtZwh2uqhKzGcoL56E6fNAy\nolUdPwohK8vSRHRxgbFps7wRrRS2poiGlWa2HWeVk1NykOYvF/67tJtu3dwALy9zMSPNoiHq7Y1C\nISsWeyrKqiorKwsnTx63jI6PHDmMM2dOwVjoxOIajQZNmzbL28taCuSOHR9Gamp2CY9M9sDvDQnD\n11mJIhRXLhca0R6C6shhKDILXreoVMLUqIlls7ExshWMjZtK2/7soKr2c04OLCPPwkFaeDR6d5Dq\ndKULUq1WRI0a0sXXV4SfX/F/+/lJt4WEVM1+dqTs7GycOnXCKpBPnz4Jg8FgaaNUKlG7dh3UqRNi\nuYSEhFqua9euA1dX2xMUUOlV1e+NsuLezs5AFKG4cd0SsurD0rUiJaWgiSDA1CACuS1bFWw+btZC\nGjZVczodkJQkIClJgdu3bQdp4dFpZmbpglSjkQLygQfMRYLz7kt+kPLtkIdGo0GrVq3RqlVry205\nOTk4ffqkZXP1hQtncOnSZfz11x8obtwRGBiUF8YhCAkJsyzXqSOFtIeHh6NeElVhHPnKRLh1C+rD\nB6x2iFLcumnVxhReN29E21oa0bZoCdGj+P9JVQS5+zkzsyBUExMFJCUJSExU5N0mWNZlZNw7TF1d\nC8KzpCAt3EZrewY2u5O7n6uL/H7Ozc3FtWtXcfVqAq5eTUBCQrxlOT4+HtevX7XahF2Yr6/vXaEs\nBbMU1qHw8fGtNGf0kgs/zxKOfOWWlQX1/n+gOrgf6vxDfK5dtWpiql0HOY/3LxjRtoyssmeDEkUp\nVPNDND9Uk5IKQjV/3b029fr7mxESYkZgoIigIBGBgWYEBBQfpNX8O5HyuLi4oG7dB1C37gM215tM\nJiQlJSIhIQFXr8YjISHesnz1agLOnTuDo0cP27yvu7tHoVCWRs/5f4eEhCIgoOZ9n92LnB/DtyIY\njVAdPgiXvXug3rsH6n//hpA3PRsgnd84p0cvy2+0hpYPQqxZU8aC7UMUgfR0WI1MExMVuHlTsBq1\n3rxZ8o5IgiCFZt26+aEqXdesWRCwQUEiAgJEe/20TWRFqVQiOLg2goNro23bR4qsF0URd+7cQULC\nlbyRc0Ewx8dL16dPn7L52K6urggOrm0VyvnBHBISilq1gnnYVDXAd9geRBHKUyfhsjdOCts//7Da\nKcrYtDlyO3aG4eFHYGz1IMzBtZ1qCCaKQGoqrDb9Wo9SC/7Ozi7+dSkUIvz9RdSrZ7aEaGCgaBWw\ngYFSqPLEWVSZCYIAf39/+Pv7W/3GXFh6elpeKCcgIeGKZVkaSSfgt99227yfUqlErVrBVjuF1ahR\nA76+NQpd+6FGjRrw8vLmKNpJ8TffclJcvpQ3so2Dy++/QXH7tmWdse4DMHTsgtxOnWFo3xGiv79s\ndZaG0Qhcvy4gPl6BhAQBV64oLMtJSSrcuCGWOOOKQiGNSvM3/dasab0ZWLqWgpf/obdN7s9zdVGZ\n+lmv1+Patat3/d58xbKcmHgDZrO5xMdQKBTw9fWFr68Uyn5+fpbl/KDOX65RI3+dL1wqeJNRZepn\nOfE3XzsQkpLg8ru0Gdnl99+gjL9iWWcKDEL2U08jt1MXGDp0grlOiIyVFmU2SzstXbkiBWp+sMbH\nSyF77Zpg8wT1CoWIWolNSOEAAAkTSURBVLWAJk3Md41SrUet/v6iXefAJKoOtFotGjSIQIMGETbX\nGwwGXL9+DdevX0NycjJSUpILXd+x+jslJRmXLl20OvFISTw8PG2MpouGduEwd3d3r/Y7ktkTw7cY\nQloq1H/+IY1s9+6B6sxpyzqztw9yHu8vbUru1AWm+g1k3YwsisDt24JVoMbHFyxfvSogN9d2fUFB\nZjz4oBmhoWaEhZkREiIiNFT6OzhYRHCwJ27d0jv4FRGRWq1GWFg4wsLCS9XebDYjPT3NKpALL9+5\nU/T2s2dPI6vQCXpK4uLiYmMUbTu869ULQW6uAHd3d2i17uU+F3dVxvDNl5UF9T/7LJuSVUcOQ8jb\n5CO6uSG3y2PI7dgFhk6dpWNrHfxhSk0FEhIUVqPXwiPY4nZg8vc3o2lTsyVQ88M1LMyM2rWlWV2I\nyPkpFAr4+PjCx8cXQL1S30+v1xcJ6sIj7Ltvv379Ok6dOlmm2jQaDdzd3eHu7gGtVpsXyh5511q4\nuxddzg/u/PvZauvMv3dX3/A1GqE6dMB6j+S8cyGLKhWMDz1sGdkaHnxImp26AmVmSuEaHy9YQjZ/\nOT5egfR02+Hq5SWdACI/WMPCCpZDQszg+QCIqCRarRZarRa1a9cp9X2MRiNSUlKKDe2srAwkJ6dC\np9PlXTKh1+uh0+mQlJQInU6H3ELnnr+/2u8O6sIhbyvIrf8TkL/s6+sLb2+f+66ptKpP+JrN1nsk\n//WnZY9kURBgbNYChg6dYOjUGblt28PeqZWTgyKbhfODNT5ewJ07tv8Hp9VKI9VHHpHCVBrBFmwa\n9va2a5lERPekUqkQEBCAgIAAm+tLs8OVwWCAXq+zBHTBcmbe33rLsvV66zDPb5OamorMzIxS/+59\nN4VCga+++haPPda9XPcvq6obvqJYsEfy73uK7pFcrz5yBg+R9kh+tCPEGn52eVq9Hjh/XoEzZxQ4\nezb/WonLlwWYzUVHry4uIkJCRDRvbiwSrKGh0vGu3MeBiKoatVoNb28fu442RVFEbm6uzXC2DvOi\n6wEgIqKh3Wq5lyoVvoqkRGlUm79HckK8ZZ2pVjCyhzyD3A6dYOjYGeYybGKxJTMTOHs2P2CVlqBN\nSBCKzKNao4YZbdqYUK+eFKjSCFbaRFyzplitZqMhIqoogiDA1dUVrq6uqGGnAVVFcerwFdJSof7j\nd2lT8u+/We+R7OuLnH4DpLDt1AWmevXLtUdyaipw5owS584VjGbPnlXg2rWiiVmzphkdOpgQEWFG\nRIQZDRtK1/7+FX4oNRERORHnC19RhNvC+cCOWPgdOFCwR7JWi9zHukt7JHfsJO2RXIYh5e3bQqHN\nxAWbjG/eLPoYwcFmdOlitISrdDHB19dur5KIiKow5wtfnQ7un3wIGI0wPPwIDB07S5cHH7rnHLai\nCNy8Kdz1e6x0sbXDU2ioGd27G/NGsdKItkEDM7y8KurFERFRdeB84evhgTv7j8M/LBBpetunXhNF\n6XSJ1qNY6XfZtDTrTc+CICI8XESbNgarzcX165vh7u6IF0RERNWN84UvADEgAHB3hzkzAwkJgtVe\nxfnLd09Fp1RKx8N26GC22lxcr56Zk58TEZFDOV34iiIwfbor/v0XOHXKA1lZ1iGrVouoX99cZKen\nBx4wc/o5IiKqFJwufPV6YMMGNbKzYQnZ/IBt2NCE8HDOnENERJWb08WUuztw4kQmAgM9kZzME/4T\nEZHzccrTO6jVDp/XgIiIyG6cMnyJiIicGcOXiIjIwRi+REREDsbwJSIicjCGLxERkYMxfImIiByM\n4UtERORgDF8iIiIHY/gSERE5GMOXiIjIwRi+REREDiaIoijKXQQREVF1wpEvERGRgzF8iYiIHIzh\nS0RE5GAMXyIiIgdj+BIRETkYw5eIiMjBnC5833//fTz99NMYOnQojh49Knc5VdqcOXPw9NNPY/Dg\nwfj555/lLqdKy87ORvfu3bFp0ya5S6mytmzZgieeeAJPPvkk4uLi5C6nStLpdBg/fjyioqIwdOhQ\n7N27V+6SKi2V3AWUxT///IMrV65gw4YNuHDhAqZMmYINGzbIXVaVtG/fPpw7dw4bNmxASkoKBg0a\nhJ49e8pdVpW1dOlSeHv/f3v398r6H8Bx/LkzubBxzDJaIblRSigXWHJBLlz7kRa3cqVc0FKUq7lS\nKAp/gLZwI0pZuZgr5UJRXGExy8evxgU6d6fOt9x8a3vbp9fjbrt61i5ee38+n7bfpjNsy7IslpaW\niEajpNNpFhYW6OjoMJ1lO5ubm1RXVzM+Ps7d3R3Dw8Ps7u6azvqRcmp84/E4nZ2dANTU1PD09MTr\n6ytut9twmf00NzdTX18PQFFREW9vb3x+fuJ0Og2X2c/l5SUXFxcagwyKx+O0tLTgdrtxu93Mzs6a\nTrIlj8fD+fk5AM/Pz3g8HsNFP1dOXXZOpVL/fJglJSXc398bLLIvp9NJQUEBAJFIhPb2dg1vhoTD\nYSYnJ01n2Nr19TXv7++MjIwwODhIPB43nWRLPT09JBIJurq6CAaDTExMmE76sXLq5Ptf+mXMzNvf\n3ycSibC+vm46xZa2trZoaGigoqLCdIrtPT4+sri4SCKRYGhoiIODAxwOh+ksW9ne3sbv97O2tsbZ\n2RmhUEjPMXwjp8bX5/ORSqX+vk4mk5SWlhossrfDw0OWl5dZXV2lsLDQdI4txWIxrq6uiMVi3N7e\nkp+fT3l5Oa2trabTbMXr9dLY2EheXh6VlZW4XC4eHh7wer2m02zl+PiYQCAAQG1tLclkUrervpFT\nl53b2trY29sD4PT0FJ/Pp/u9GfLy8sLc3BwrKysUFxebzrGt+fl5otEoGxsb9Pb2Mjo6quHNgEAg\nwNHREV9fX1iWRTqd1v3IDKiqquLk5ASAm5sbXC6XhvcbOXXybWpqoq6ujoGBARwOB9PT06aTbGtn\nZwfLshgbG/v7Xjgcxu/3G6wS+X/Kysro7u6mr68PgKmpKX79yqmzR07o7+8nFAoRDAb5+PhgZmbG\ndNKPpb8UFBERyTJ99RMREckyja+IiEiWaXxFRESyTOMrIiKSZRpfERGRLNP4ioiIZJnGV0REJMs0\nviIiIln2BzQKNGAGnBgwAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "\u003cmatplotlib.figure.Figure at 0x7f7a18df6b50\u003e"
- ]
- },
- "metadata": {
- "tags": []
- },
- "output_type": "display_data"
- }
- ],
"source": [
- "# Train our variables.\n",
- "\n",
- "# numpy is used for its asscalar() function.\n",
- "import numpy as np\n",
- "\n",
- "num_training_steps = 10\n",
- "\n",
- "def train_model(inputs, labels, wb, optimizer, num_training_steps):\n",
- " loss_at_step = []\n",
- " w_at_step = []\n",
- " b_at_step = []\n",
- " for step_num in range(num_training_steps):\n",
- " loss_at_step.append(run_step(inputs, labels))\n",
- " w, b = wb.variables\n",
- " w_at_step.append(np.asscalar(w.numpy()))\n",
- " b_at_step.append(np.asscalar(b.numpy()))\n",
- "\n",
- " print(w_at_step)\n",
- " t = range(0, num_training_steps)\n",
- " plt.plot(t, loss_at_step, 'k',\n",
- " t, w_at_step, 'r',\n",
- " t, [true_w] * num_training_steps, 'r--',\n",
- " t, b_at_step, 'b',\n",
- " t, [true_b] * num_training_steps, 'b--')\n",
- " plt.legend(['loss', 'w estimate', 'w true', 'b estimate', 'b true'])\n",
- " plt.show()\n",
+ "## Next Steps\n",
"\n",
- "train_model(inputs, labels, wb, optimizer, num_training_steps)"
+ "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)."
]
}
],
@@ -572,7 +312,7 @@
"colab": {
"collapsed_sections": [],
"default_view": {},
- "name": "Eager Execution Tutorial: Working with Gradients",
+ "name": "Automatic Differentiation",
"provenance": [],
"version": "0.3.2",
"views": {}
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb
new file mode 100644
index 0000000000..d9a9bffbb4
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb
@@ -0,0 +1,443 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "k2o3TTG4TFpt"
+ },
+ "source": [
+ "# Training Models\n",
+ "\n",
+ "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n",
+ "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n",
+ "\n",
+ "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "3LXMVuV0VhDr"
+ },
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "PJ64L90aVir3"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "tfe = tf.contrib.eager # Shorthand for some symbols"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "eMAWbDJFVmMk"
+ },
+ "source": [
+ "## Variables\n",
+ "\n",
+ "Neural networks are characterized by a set of parameters (sometimes called \"weights\", sometimes called \"variables\") with fixed shapes and types, where the actual values are computed and adjusted during the training process. The `tfe.Variable` object encapsulates such parameters.\n",
+ "\n",
+ "Recall that `Tensor` objects are immutable, i.e., the underlying value of the `Tensor` cannot be changed. `Variable` objects act like `Tensor`s but are mutable via calls to `assign`, `assign_add` etc.\n",
+ "\n",
+ "For example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "itxmrMil6DQi"
+ },
+ "outputs": [],
+ "source": [
+ "v = tfe.Variable(1.0)\n",
+ "assert v.numpy() == 1.0\n",
+ "\n",
+ "# Re-assign the value\n",
+ "v.assign(3.0)\n",
+ "assert v.numpy() == 3.0\n",
+ "\n",
+ "# Use `v` in a TensorFlow operation like tf.square() and reassign\n",
+ "v.assign(tf.square(v))\n",
+ "assert v.numpy() == 9.0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "BMiFcDzE7Qu3"
+ },
+ "source": [
+ "## Example: Fitting a linear model\n",
+ "\n",
+ "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n",
+ "\n",
+ "1. Define the model.\n",
+ "2. Define a loss function.\n",
+ "3. Obtain training data.\n",
+ "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n",
+ "\n",
+ "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "gFzH64Jn9PIm"
+ },
+ "source": [
+ "### Define the model\n",
+ "\n",
+ "Let's define a simple class to encapsulate the variables and the computation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "_WRu7Pze7wk8"
+ },
+ "outputs": [],
+ "source": [
+ "class Model(object):\n",
+ " def __init__(self):\n",
+ " # Initialize variable to (5.0, 0.0)\n",
+ " # In practice, these should be initialized to random values.\n",
+ " self.W = tfe.Variable(5.0)\n",
+ " self.b = tfe.Variable(0.0)\n",
+ " \n",
+ " def __call__(self, x):\n",
+ " return self.W * x + self.b\n",
+ " \n",
+ "model = Model()\n",
+ "\n",
+ "assert model(3.0).numpy() == 15.0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xa6j_yXa-j79"
+ },
+ "source": [
+ "### Define a loss function\n",
+ "\n",
+ "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "Y0ysUFGY924U"
+ },
+ "outputs": [],
+ "source": [
+ "def loss(predicted_y, desired_y):\n",
+ " return tf.reduce_mean(tf.square(predicted_y - desired_y))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qutT_fkl_CBc"
+ },
+ "source": [
+ "### Obtain training data\n",
+ "\n",
+ "Let's synthesize the training data with some noise."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "gxPTb-kt_N5m"
+ },
+ "outputs": [],
+ "source": [
+ "TRUE_W = 3.0\n",
+ "TRUE_b = 2.0\n",
+ "NUM_EXAMPLES = 1000\n",
+ "\n",
+ "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n",
+ "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n",
+ "outputs = inputs * TRUE_W + TRUE_b + noise"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "-50nq-wPBsAW"
+ },
+ "source": [
+ "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 293
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 1210,
+ "status": "ok",
+ "timestamp": 1527005898290,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "_eb83LtrB4nt",
+ "outputId": "3873f508-72fb-41e7-a7f5-3f513deefe38"
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEDCAYAAAA2k7/eAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXlgU1X2xz/pAhRautCWUsCwWVlcUHHGBUFQcSg7uM8P\nFLUICo4VpygObihI3UdmUHBB0IGZQbEgFNGCqKgMolV2pKylCy1pukDp+n5/3LxmaUsDTUjSns8/\nbZKXd09C+b7zvvfccw2apmkIgiAITR4/TwcgCIIgnB9E8AVBEJoJIviCIAjNBBF8QRCEZoIIviAI\nQjNBBF8QBKGZENDYE+Tk5JCUlER+fj7+/v7cdtttTJgwgcLCQhITEzl27BidOnXijTfeICQkxBUx\nC4IgCOeAobF1+Hl5eeTn59OrVy9OnjzJ2LFj+ec//8mnn35KWFgYCQkJLFy4kKKiIh5//HFXxS0I\ngiCcJY22dKKioujVqxcAbdq0oXv37uTm5pKWlsaYMWMAGDNmDF999VVjhxIEQRAagUs9/MzMTPbs\n2cNll13GiRMniIyMBNRFoaCgwJVDCYIgCGeJywT/5MmTPPLII8ycOZM2bdpgMBhcdWpBEATBBbhE\n8CsrK3nkkUcYNWoUN910EwDt2rUjPz8fUD5/REREg+eRtj6CIAjuo9FVOgAzZ86kR48e3HPPPTXP\nDR48mE8//ZRJkyaxcuVKbrzxxgbPYzAYyMsrdkVIbiUqKkTidCESp2vxhTh9IUbwrTidodGCv23b\nNlavXk1cXByjR4/GYDCQmJhIQkICjz76KJ988gmxsbG8+eabjR1KEARBaASNFvwrr7yS3bt31/na\n4sWLG3t6QRAEwUXISltBEIRmggi+IAhCM0EEXxAEoZkggi8IgtBMEMEXBEFoJojgC4IgNBNE8AVB\nEJoJIviCIAjNBBF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCaI4AuCIDQTRPAFQRCa\nCSL4giAIzQQRfEEQhLOk0GTi84R7+XbIDXyecA+FBSZPh+QULtnTVhAEoTnx7YzHuDflUwyAlv4z\nizEwfNFiT4fVIJLhC4IgnCWhhw9hsPxusDz2BVwi+DNnzuTaa69lxIgRNc/Nnz+fAQMGMGbMGMaM\nGcM333zjiqEEQRA8TqHRiGb5XQMKjV08GI3zuMTSGTt2LOPHjycpKcnu+YkTJzJx4kRXDCEIguA1\nXJ/8OosxEHr4EIXGLlyf/JqnQ3IKlwh+v379OHbsWK3nNU2r42hBEATfJjQ8wic8e0fc6uF//PHH\njBo1iqeeeori4mJ3DiUIgiA0gNsE/+677+arr74iJSWFyMhI5s6d666hBEEQXMLRjAwW9u3FWmN7\nFvbtxeGMDE+H5FLcVpYZERFR8/vtt9/O5MmTnXpfVFSIu0JyKRKna5E4XYsvxOmNMb53xQhmZh1T\n5Zalx5h3ww08cfSop8NyGS4TfEe/Pi8vj6ioKAC+/PJL4uLinDpPXp73Wz9RUSESpwuROF2LL8Tp\nTTEWmkx8O+MxQg8fIjory67cMtZk8po4z4SzF0+XCP706dPZsmULZrOZG264gWnTprFlyxZ2796N\nn58fHTt25Pnnn3fFUIIgCC7FdhHVXFSZpcHyM8vGqWgKuETwX3311VrPjRs3zhWnFgRBcCu2i6ju\nBp4JDKR7QACZ4RH839dfezAy1yMrbQVBaNbYLqK6AOgaP4L4w7lMSt+NsXt3T4bmcqSXjiAIzRpf\nXUR1LojgC4LQrPHVRVTnglg6giA0WXy1jbG7kAxfEIQmi6+2MXYXIviCIDQZbGvqC41Ggg9k+GQb\nY3chgi8IQpPBMaOfE9vRrq7eV9oYuwsRfEEQfB49s/dbn2qX0V8QEcHiq/7odAWOyWRmxoyNHD7c\nFqOxkPffHwX4uzv884YIviAIPk2hycS/B1/HJVnH2In9StnK7heelWc/Y8ZGUlLGAwbS0zWmTFnO\n/PnD3RK3JxDBFwTBJzmakUHquOFE5mTTpbqaAcD1wDygQ1AQ1UOGnnVN/eHDbcHmHuHgwWDXBu1h\nRPAFQfBJUscNt3a2BJYDdwG9gRNDhp5TNY7RWEh6uvUeoWvXEhdG7HlE8AVB8BkKTSY2PDqVgB+/\nI8ZstvPrg1HCvz22I3ec42rZ5OTBwFKLh1/EggUjqapyTezegAi+IAhejz4pq23aQBuzmWHAAuz9\n+t8CA8mPH8Edya8RGn5uXS7tu7w3vS1aRfAFQfBq9ElZR/vmbuBZwGgwkN0hlqEr19C5a7dGjdXU\nJ22ltYIgCF7NtzMe4xKL2IPVvrkAuAgwjBzDpPTdjRZ7aPqTtiL4giB4NaGHD1GC1WDR7ZvktqFs\nbH85r2eMJiHhUwoKzI0ey2gstBtJJm0FQRDciF5u2anARGZ4BC169eIBlI3TBsuk7MbNPJ60Sdkv\nuQa279CApSxaNOasx7NdbNWhw0mGDn2P7OxImbQVBEFwF/rEbNba1cysqKjZSPyF6mo+GzWW0MOH\nOGHsUjMp62i/HD7cttZK2eTkwYSHh51xXEffftSopaxffyMAERHes/euKxDBFwTBK9D74HwO9u0R\nCs3E11FT71gzbzQW1RJvZ7L+ui4cTRWXCP7MmTP5+uuvadeuHatXrwagsLCQxMREjh07RqdOnXjj\njTcICXFuZ3VBEJo+u7Zt48vRQ+ladpqDBgNRrVtjAIqxL7fMrKfE0rFmPjl5EHfcsY2zFe+6LhxN\nFZcI/tixYxk/fjxJSUk1zy1cuJBrrrmGhIQEFi5cyDvvvMPjjz/uiuEEQWgCfDkmntllp5XMahpP\nnzyJBsQDy4BC/MhoFcawxcvqfH94eFit7P1cxLuuC0dTxSWC369fP44dO2b3XFpaGh999BEAY8aM\nYfz48SL4giCwa9s20sbG0+10qZ110x14JSwMf8LZZO7HKt6G0+Hs/8dSFi3q69S5z0W867pwNMS5\nzBV4A27z8E0mE5GRkQBERUVRUFDgrqEEQfAQZyN8+qTs6VUruUjTOIS9dZMNxAwczN8Pjyc9fXTN\n+87GUz8X8T4XzmWuwBvwuknbqCjf8PklTtcicbqW8xXn1Kmf2wlfy5bL+fe/76p1nPnECVbc1J/e\nmZmUAEOBpajOltHAAYOBsBtvZMz7i1g3ZZ2dLRMXV4qfXxUPPZTKwYPBdO1azIIF8UREnJ+Muq7v\nMisrHNu5gqyscJ/423Cb4Ldr1478/HwiIyPJy8sjIsK53ha+UAIVFeUbpVoSp2uROGuzb18QtsK3\nb18QeXnFtTL/+PIVzMjMtGuN0BUYDsxqFcRfjuQCUFEFs2dfT1mZ1ZaZPXsQ99+/qubCsnWrRlnZ\nUubNG+R2W6W+7zI21oTt/UlsbIFH/zacvdi4TPA1+65DDB48mE8//ZRJkyaxcuVKbrzxRlcNJQiC\nl1DfJKlueRgwcUH6FPBbb+fXtwF+A7a0CuLmVevszlmXLVNX6aQnbRVfneh1ieBPnz6dLVu2YDab\nueGGG5g2bRqTJk3iL3/5C5988gmxsbG8+eabrhhKEAQvoi7hKzSZCN34D17hUfIoYS4VLKu29+t3\nderEHWnfOd3Vsq4Liyfr58/XXIGrcYngv/rqq3U+v3jxYlecXhAEL8VW+ApNJj57YAKl337NDSgp\n7mD5GY+yccotO1FNfn8RuXknSUhY6ZQlU9eFJSlpQ7Opn3cVXjdpKwiCb/LtjMeI/fZr7sKayb9k\n+RkG3AkstuxEFRYRwr33fe60JaNfWPS5gTvu2Far742v2CqeRARfEASnqasMs9BUwOJxCVyanU4+\nUIgSeAOqffFLQPuwMAwDB3N98muYTGamTv2c9evB1pLJyPBvMOOvr++NyWQmKcn36uLPNyL4giA4\nja3g/pqeT+DqP3BJ9UHmY83ql6E2J9GAn4Hc9pezLCqRblRzHX4251iGrbNvMh1mx44nOVPGX59v\n76t18ecbEXxBEJxGF1wDJ5jM5fyjOrNWs7NC4G3gd/zZenUS3/74ol0LY6toK2c/KKiCIUPgwIE4\nsrLOPAnrOHl7/PguDhzowaZNucDnqE488U26AVpjEMEXBKEW9a2g7djhGMb0eG5mHa3R6mx2to6r\nWcUPwCrC9uRbXjEDqaxfD+HhO4GBNe8oKytl69Z8evUKtjuTPglr36++nPbtnyY39yrgJFlZUxg7\ndgFm85PY3mMYjZVn/BzNFRF8QRBqoSySEcA60tPD2bp1CY/cW0nv1GfpDeQBuWDX7Kwc2A2s4l+W\nV04C+ZbfU4E7KS01UFqqERT0DKWlLYGZVFcbyMrSqK5+gVGjate2O9o1YWGvACNrYi0o6ITtPUZY\n2GmSk2+u873N3eoRwReEZsDZZrrKElkH3Ik/WxmSNYaCOdX0B0qAB4BVwNNAJ9QFIB9Y1vZeKNoO\n/Aj8iWuuWU6LFktZvx5KS62iXFraDyXS1ucKC411irGjbw/tsL0TCA8/Smmp9fHAgQE1n6059bp3\nBhF8QWgG1JXpnqk1gdFYyK/pp4nnEv7ATiKBUGCA5edyIALoDOwliNfIJCzsM7744g/MmfOz5Zyr\nSU4eTnh4GAkJn5KSYmv8nLT8tBXuzDpjd/Ttr7mmmhYtrHcCM2eOYs6cule9Nqde984ggi8ITRDH\njP7AgTY01JrgxImX+emnYsrKuhKifcxf2EAw0BdqGp6lAnehWiMUA7tpxRtsB8Ixm1vx7LPf0aJF\na8s41nYrtgunjh/fRVbWFEs8y/DzKyYm5gQrV1ptGltqL7q6pdbdyaJFRiff27xr9UXwBaEJ4ijm\nsbFzcJwQdbQ7Nm8uAm0ioxjFzeykEHgCa06+HNCnVf8H/MKlrGUssMvyTDybNy+kqOhBHD1z2xW5\nBQVXMmvWOvbtC8JorCQ5Of6M9lJj2hj4agsEdyGCLwhNEEcxj4jowlVXnbk1gb92kkR6MM/yzCrs\nnfM2wK/Atxh4mXuB14A1qJ6X6hynToXSkGceHh7Gv/99l090Hm1qiOALQhNEedcFqInXlvz++06O\nHGmFn18nOnSoAuztjuh2+7k07Q36Y5XrEuzLLb8DXmYDMAhQ1TLV1aUUFS1BOfoltG5toqiocZ65\nlFK6DxF8QfBRMjIOM27cKgoKOhEefpSVK0fRtavyspOTB7N16wKyslR9elnZGMrKlgHxpKauZfPm\nLwgOziW8bTsuP/4SxvTddMZe5IcCs4COwC78mc8e4D+Wo4rp1CmW7t0rSUmZgC7w/fq9xZ49cy0x\nZTJzZm1fXm+toCyd2oIupZTuQwRfEFzI+cxOx41bVSPopaUaI0c+w9VX9yArK5zYWBMREUa7lasQ\nhFoDO4PiIhN9i/7ENVk/0Qb4G/AKcDvKq28DbAPKgHXczKqaupyLgRGAxv79s3jvvTuxnRQtLw+0\ni2nOnKW1JlQbEnQppXQfIviC4EKczU6dvTDUd5zJZCYnR8O2nUBeXsuasdUuTHOxN2X2AH3wYz/j\nuYgYNHoDpZYjwlBVOCGWM5qBv/Moyqu3LacEMHD6tCrBtP18Q4akoZorpALBbNqUQ0GB2e6z2Qt6\nIZs25TJkSFrN55NSSvchgi8ILsTZ7NTZC4PjcWVl7wGQlpZLdfVMbNsJaFqE3dgFBbG0ajWL06fb\no0Q4GH/SmcooWgNXo8wZ/Qy3AWuBTGB/y1DSus7B//dDVFUtAyqBY8Bky/mV+Dt+PiXWa8HSJNls\nHk5S0lK71saHDlUCHwPDgLWYzY+Tnm79HqSU0n2I4AuCC3E2O63vwmCb0cfE5PH99/arUX/80Q+z\neSLUallWTnCwieJi69iqdcFsYmJe4FTO10xhA3HAPuBFrEK/BJiLMmwOGAL5vPsLxPVpz6fJg3n0\n0dWkpgLkoMR+Hcrw2QU8SEzMJ3YtjWfOvJJNm/6H2Xzmjpb6pC+0q3WslFK6DxF8QXAhzmanHTpk\nk57+L5SBUkSHDrZ7waoeNtAeVd9uFfGiohzL744ty/IpKTlOy5azKC/vhqYFoaZdDbSqzOMeNjCX\nusstw1E9cF7yv4viqo9hv4Hd+zVSU5+ksjIEg8FAy5a5lJXNQ9PiwLLQKiZmPgZDJCkp92N7pzJw\noL/dqlr9oud4kevS5UKMxsI6jxXcgwi+ILgQ57PTQLDbG+o9TCazpc1viuX1AcD1wDwgBmhBdXWU\n5fjrUHl5NKqTzd1o2mbKyu5CtTK7k0BWkMjtdM2HFtRfbvk9MI+HoeoKm6MKKS9vA1wClHD69KXA\nBJt3LeP06WNkZ3fA8U7l3/++krouenXd/Yh9c35xu+APHjyY4OBg/Pz8CAgIYMWKFe4eUhC8nuzs\nSGyFMjs7khkzNmI2P255vgBVUdMH8ENJdjzqYvAekAHMwX4dbIjl8XX48xBTeJvLLM9uwb7c8kng\nAiCdIBaRALyB/YYka1G1O/r5P8T+viAEaFeniNd30bMV97i4UmbPHiT2zXnG7YJvMBhYunQpoaGh\n7h5KEHwG+4VRbTh+fCdVVRdhFdV1wAzL4+G0aJFEefkBVG/KAqAH9gIcDBThzxb+j6uJRTnt+j1E\nf+ApoCfKfd9LIPN4CUjEKub6VuOlqEla2/PnYX9fUMw111Rb2hA7l6HbintUVIistPUAbhd8TdOo\nrq529zCC4HHOpgbfcWFUVtYITKZZwDisjQysgtuyZSTl5UlYBVffHlx/vJcW/EoiH2EE2qLuC/Qz\nhANGlNjPYxgwGnXXsAw4hf1W48ss77I9fy7qziIAP78sbrklnDfeGC4Zuo9xXjL8+++/H4PBwB13\n3MHtt9/u7iEFwSMkJq4hNbUt4E96egDl5Z/z4Yf/V+ex4eFhREf3tlsYdfp0F5RfHw38jvLvwwGN\n4uJg7DPuzsALQCcCWMOdfEJHqJmYreuSsA94jf0og8d2/mA2yh5S7REgwTKOyvZbtjxAWdlUoAug\nMWKErHz1Vdwu+MuXLycqKgqTycTEiRPp1q0b/fr1q/f4qKgQd4fkEiRO1+LNcZ44Yeahh1I5eDCY\nrl2LWbAgnoiIsFrHfPVVNqA6RYLGjz++WvO5Tpww88ADKWzapAF5DBgQRocO2PnfaguRGTaP56E8\n/BIgG3v5Pgp0pxWf8hc+IRi4FPtLQk9Urm4GdhDCAn4BuqPyfNsjL0c1QHseqEB1vDcAd9Kp0zx+\n/fVxpkxJZd++X8jP38vhw0amTl1d5/dwNnjzv7ktvhKnM7hd8KOiogCIiIjg5ptvZvv27WcUfF/w\n9XzFf5Q4XUNCwqqaUsmtW4P57rt/sHHjBMLDw2r62eTktKO6ugu2QlpcHMK+fUctG4Cssus5k5Ky\nBPgNeBWIRAl4H+yFuDd6GwNYgLVBcQkBHGIqM7kQ5ei3pnb1zS6gCEjmJ9Qq226Wcxc5HKkvv7rC\n8rt1nLCwzlRV+TN//nASElaSnj6DzEx9Edi5Z/re/m+u40txOoNbBb+0tJTq6mratGnDqVOn+O67\n75g6dao7hxSEs8IZ3z0jwx94B1XXspOsrF4MGrSEjRsn2PWzUatHrUJaWdmKQYOWEh3d27K61FbM\nNVT3Guvyp8DAX6ioGIO9ZBuAdJSFcyd+7Gc0/ehOUU0bYw1l7tyLtQ/O98AxgviI7aisvrvlqF7A\nXtQU7oVAS9RkrS781cDdNWfu3n1pzfcgPW58H7cKfn5+PlOnTsVgMFBVVcWIESPo37+/O4cUhLPC\nmRYHJtNhVCHjcvQtQbKyNJKSllJQYFuHPgyVscehWo/dR1bWf8nK8kdNehage/JwANs+OFBJUFB7\n2rV7kZycSNS0613AZtQdwAECuZ1EVnARqqmZ7eWjI+oeIBz4BQMvk4Ta+1XP6kNRtf0VwHPAEWAp\nEAXMx8+vkN69+9K5cxHwHtnZkbJdYBPErYLfuXNnUlJS3DmEIDQKZ7LWdu3iLJOr9hOnq1ZVomm7\nsWb1eulxgeXnRqADavJ1OJCEqoTRd4PNRU3QLgBKKSp63tJL/hmU4P8XmI4BE4MZSD920hvV0cYP\ne1PGhDJqnuIOVOb+PKp/zjKgHFWRU2zzGb5HZfnqDDExc9mwwdrKWL/zueOObTV3PrJIyveRlbZC\ns8aZrLVbt5Ns365qz21lVrUviEOJahCqQUEgyjL5K9a+MwuAKajFSvYNz2AkyqefZxnNgLJdDgNt\n8eN1pjKDi6gkFHUPEYqa2l2GtbNlJvAmM1C1OXrzhDCUPbPaMsYCwsJ2YzYPx/Hi1a5dnN1nru/O\nR6pzfBsRfKFZo2etGRmtMZn2kZFhJCHhUzsv33qMH/v3P83p00bUQqRhwHpU24NiVLZ+P8qqWYeq\naTegxHYJyj5xXK2q/x6MkvA2qJLMSIJYwiNs4VpqbyLeFdiBWob1Bd1YxR2oOwioPX2rHgcGmtiy\nZQJJSUvZtCnHIvzqmG7dTtl9L+LXN01E8IVmjb5w6J57PmbHji5kZYWwY0cOP/zwHuXlFwD5XHNN\nMG+8MYJZs75jx47nsbY+eB3ohxL7oSj/Xq+jz0ZZKmGWn0eAVjiuVlVoKKPmYcBAC/L4Cw/QAlUh\nb9s8Qd9E/DCQg4G5bEU1MzuFaocQgrJwnkDdfRxH1c8vY8CAkJrPW1BgJimpfntG/PqmiQi+0OQ4\n212nTCYzX32VhbWG/l8cP/4Mutilpi6jRYuNZGWFY93cIwNV6a7L8XxUhYttHf0ylKWi96XRPfVi\nlOtegrJgwlGZfSEB/I9HeYCXUFOqtvcDbVDSvhmYxzxURq8Bn6CqbabYjP0CMBbdVoqN3cE//zm+\n5jM3tEJW/PqmiQi+0OQ42z1RZ8zYSEVFP6zyGoKj9ZKSchz4BiW5T6KyedvVqi+gJktt31cIvI+q\njLH11FehBD8Y/QLhzxbGE04HrF1yjlG7q2UJLfgHPwArLec5iTJ42tuN3bZtB6677hNLtY2Z5OTx\ntS56Z7owSsuEpokIvtDkcPSfMzL87TbpePLJK5k792ebTUayUTX2+i5MjguTTKiKmj4o774QVSrp\n2Opgj8P7TqKmWG1ragpQ9fVRwDEMHORmRtGXHXRB1ebonW3uRuX/eqOFDYRyqtfrxBauo23bzhQV\n7aBduzgOHy6nqGgnaq5AjT1oUIsGBVs2C29+iOALTQ6r/1wIrGXv3kPs2KGqY9LTNdasmUVl5eya\n15V4ZwMXofZvLUJZNsrDV4+fw75LDdiLeybKdHkS1aYsFHjA8vM9VK+aXqhFVOpcLXmTKXSnDfZe\n/RKsPStPAj8Bb/MhsbGZpG+6tdbn7dv3LYqKpqAvuwoK+onk5IRaxzkiE7PNDz9PByAIrsRkMlNe\nXkFY2AcEBLwCDKWiwr7LTGVlZ8vjVNRkaxFqknMsSozboiY6W6BE+yrss/k+KL9cL4FcjppwDUDZ\nQS1R+XmY5fho1H+1/6F8/+W04s88yqNcBfzB4ewRqPqeA0ApLXibn4AJhIZ2r/Mzq5LKcJTFNJKe\nPS8/45yFjtFYiLrEgEzMNg8kwxe8nrOZhH300S9Yt05tuWetbdFw3A5Q/QxGTWr2xl5y+6Hq4/X3\nV1PbqgkDLkbZKDp6q4IdDsf/BDwGrCKA9fyFJfRE4xDKvsHh6L3Ad0AyM7Dtf1lYmFHnZ7auE1DH\nXXjh6TN9nTXIxGzzQwRf8HrOxmv+8UfbLvB6bcsAVHVMIcq6Kbc8PoaycRzr1k/avL8Y5bnvRmX9\nB1CLqqC21/+75Zi7gVmoydTjwP34cZw7uI8LqKpZLfsAyux5DGsPnP8BvxDLWraj7gqWoyZ9Aykp\naUtBgbnWxc5RuBcsGElVVcPfq0zMNj9E8AWv5+y8Zj1710V4p817v0cJcg9U1h2E6lLZBiW90ZZj\nJlse630oo4CHULZJAaqRWm/LuZdg7SMfClwLfInK9A1AJa14iGmspSWq4YFt82Mj8E/LmX/AwFv8\nRHT0ajiul4BqqN2n/CkqaklS0sZaIu0o3BERvtHhUTj/iOALXo+zi4BMJjMtWxZjbTlcSfv2p+jQ\noYqYmFOsWxeNmjgNQVXbPIG1cuYN1H+HauADVNOx6Vjl+VUgFtXorA/KytmL/cbeT6IuAN2Bv+HP\nVhKYRDBVhKHW49ree8SiMv1iYBVhFF74MqN676CkJJy0tGVAlkMMS2RiVWgUIviC1+Ho2c+ceSWO\nXrPtMR06ZAOB/PCDH2ZzT/SOM4GBc7jiigt4440refTRNahM3LZ2XpffdcCzNs/PcXjdgLqABKP6\n4kRZXg/HWk+Ti6rq6QQYaM10pvI6ccAh1KXAcQeqXagcftfVf+XzVbNqPv+QIWmoLQhXO8QQjtFo\nbvwXLDRbRPAFr8Pesy9g69YFREf3tpuwTUhYaXPMv7AX8uXAXVRUXEpqan9++WW+peVwFUpiQdkx\noGrsq7AX1hhqL3tqgbXR2WzUHMCtKBtHb5u8kAC+ZzjzuAhqvPo4y1nuxtp4YR/wHZEcaD+Jrz+c\nbPf5rXc09s3aYmN3kJw8HkE4V0TwBa/D3rNfR1bWk2RlWSds580bxKZNucBnqLr2AOBDy/GjUR0r\n56JWn75ETs5L2Fe5t8Bq59S1+2suKovXNwjMBfoC/0JZOlGoOv3lqHmAUYCBEN5iCjsJwbbxMDxt\n+WlEratNAl7hCoYOfYCvLRuB22Jt1uaPyTSXdu3i6NbtVJ2rZQXhbBDBF7wG3aZRu0Ppq17bYJt9\n792rcdllb1NW9kfURGknVL2LrdeeidqnNQJrWwMsP09TO6PvgrU12V7L43hUnX4R9nbPMlRWfxKV\nvz+PH/uJJ5o+VDAX1SvT9uzdUReAjqgan9dYQlBQFR9+OK7O70GqZwR3IYIveA22Vg5otG37EuXl\nJzl9Wm8ZUMC+fXuprn4Re4G3ldeLUNOhQ1HevGPVTjFqktb2Ob2RgV4FvxtlEd2Ftbe8fv5y1F3E\nt0B/gunNFPbQBWtdTrHD2fcAJ4C5PI1a2KURGvqi6744QXASEXzBa3AsvywpiaK6uhxYCORjMBRS\nXd0fewFuR+3e7yFY+9G/i/1WIUUoF/0pVO69H2XLrEb5+WGoCdqXUP89CrHtUaNkPRQD2VxPF66h\niDhUBX6X0UVuAAAgAElEQVRLyxHxWHtiHgR+IJhv+NYS0xLgd/r0EWtGOP9IawXB45w4YSYhYSUH\nDuzFdql/dfUhVOVLS+AhNK071kVSYF3s9CyqNn4Z1lYJuhV0G9bSS/189wDXAPcB/ijhH2455/2o\n7cCfQOXl01F2zyrURSKHAJ5kEg9yDUX0Rjn8D6LuJf6Gala8C7WI6oer/8qivbsIC/sSVc4ZCEzn\nxIm62yQIgjtxe4b/zTffMGfOHDRNY9y4cUyaNMndQwpegG3ZZExMHgZDJdnZHepsjfDQQ6kWK8ex\nX/x0rJt+L0fVzt+OypL1TUNOowTeH+XNL0CJ/SFUZh6GyvR160evrNmCEnRQUv0qtdsi2/aoAX/2\ncieP0Qkl246LqK6wRP078D3h7G8/hc/evIvw8DAGDowmJcW6w5T0rRE8gVsFv7q6mtmzZ7N48WKi\no6O59dZbufHGG+neXbKbpo6jH6+EfDTp6Rrl5e/QokXrmjr7I0d0K0fvF78ENXG6DjWRql8AilCZ\nfABqQrYUJdIXUbsscwLwIsryse1cqS+gMqIyeX3B1KWoUk1be2h/zWMD+dxOEp1Qa2mzqb2Iapcl\nor/zHPA05GrMmbOURYuM0rdG8ArcKvi//fYbRqORjh07AjBs2DDS0tJE8JsBGRn+WCtfirGVx82b\n8ygq6g74k54eQIcOv6AmQnWhPWY51rZ0ch5KmF8GrkZZO9NRG4E4ZubBKHHvbvndtsGZXrnTEmuZ\nZXeUXD+OtavNVmASBhYxhme4kZyada/hqBoix0VUR4BlrETdbahY9JWxUnkjeANuFfzc3Fw6dOhQ\n87h9+/Zs377dnUMKHka3cvbu3U99PeSLiqqwzchPnnyRsLBXMJs7oCpkOlN7pWs0ag9Z2wqd5aiL\nQ0vs5fc3lGUzHXVn8aHl+TxUM7OHgR9Qwr4AlZf/EVv7BvJoxTKmMZN5DiPehSoYnYOq9P8dSOav\nQDLWuxn1WcW6EbwJtwq+pmkNH+RAVFSIGyJxPRJn3Uyd+rnFyvkMW8E2GNqiaUtQAh2DdW/YYIqK\nqhk6tA2pqX6orQIN1M6hg1Btix07YZajBFvvn3MUlbFnoCyhLOy3F1mG6pXzrOW5EahGadZiSj/2\n8Sce4BKUfeM4Iqj7h2LgZ+C+las5tKyYgwdX07GjCU2rICtrNV27lrBgwUgiIs7/34ov/H36Qozg\nO3E6g1sFPyYmhqysrJrHubm5REdHn/E9vtDlLyrKN7oReiLOffuCUNKod3pUQqtpLVFTnX1Q2fda\nrFn+cNLSnqBt21YUFenyOgwl4hEosR9qeY/tRWAr1lYJlUCZ5fl01DKnO1HzAbaSHYK6IDjePagW\nyi1ZwZ9ZSRRK7B0bJ+9EXbIOA/P4KzCPqsX1t2uuqjr/f9O+8PfpCzGCb8XpDG4V/EsuuYQjR45w\n7NgxoqKiWLNmDa+99po7hxRsUOWOq5zaOMRVqD4wBajVrh+ibJTTqHLIO1HSeT3wH2xFt7z8QgwG\n6ySpyqFjUcuWdGtoKMoaCkNV1kSiBL81+mbg1sVYRcBbqAzfceGVfZ8cP78f8av+HxN5kWiUQXQZ\nSuyHYu/qlwDbCWAZe1AXDqSDpeAzuFXw/f39mTVrFvfddx+apnHrrbfKhO15xFrueP42qU5OHszW\nrQvIyrLtJvMSyh/XbZxAVAXMIpS9UwSUU1b2V1Stew+sE7ftUNXtF6JEvhR1Z6B78Ccs4ziutu1v\nGddoOacR5d/HoCqBVJ+cmBgThTmnmcrrNVO8rbGK/TqsG5OUAIb7JnHqxLWQ0s0ynvj0gu/g9jr8\nAQMGMGDAAHcPI9TBwYPBOL9xyJlxdpvB8PAwoqN7k5VlK8DtgV9RkqnbOONQojsCdVF4EXVR6Im6\nIPwN6wVjBipTvxh157Ac1YuyBEgE3qb2att1KMG3nW69HVXW+RtgwJ9D9Mx5mb5Qk9lXAL9YzqqL\n/fdArrErr//8ExVVgRQUmJESS8EXkdYKTZiuXYvZurXhjUOcwXGbQb2WPiOjNSbTXiIiutC9eyXJ\nyYOJicnDXoBboeri11HbT9d/jwIWo5oRnLa8pxRVNtkVtSh8JKoLpm255nLUBWW25Ryhlvd84zBW\nBaq0cwYQTkve5EFeJghVgW9bxT8Pa9u0X4G+H3zMjcNGEGbZSUpKLAVfRQS/CbNgQTxlZWfORPXM\nvS7hts3gHfvc/PBDMWbzg+gymZX1Pjt2BLFmzVpURfoLqJbCe1GLnlRFTm0/HcvvIVgbmC1Bib6+\n4YgJlYNrqDsAx7qZwyg//3eU7/8ZyhKy9sDx89tD9+4XcOD3KQzj31yE2tMqHzUlbHvGCNQ9QC6w\ns/f/MX2YbR2/IPguIvhNmIiIhjNRxxWxWVnL2bFjZK1NRxy3GSwpsb0AFKJE/lkqKx27wFdYfgaj\nJmuXo7L3LagFSgtQ3vpfLOcyoKpt9K0D9bJJtayp9sYkO1Fi74eaatVQ62BNGAwLMBgKCAgopLz8\nSTJ/f5XH+DctsC/UdOyGvxdY3fmv9L7iYj4Su0ZoQojgN3McM3clzKvsNh3ZsuUFIiO70bLlLMrK\nugIFVFaWohqSrUMJdBeH8/RC9YzvgzJJ/FENyu5CyeoW1MpWfd1qqOW9GsqnL0RV46iWxMHBwbRu\nncHJkyGcPDkLuBJ1FzAFa0Y/E1sZ17TOaFoorQK2MLG8KxEUcrHlXbaRhqIWUYUDu6KiefS7//FE\neIQLvl1B8C5E8Jsp1s1Gcqg94Wm/yjUnpx05OX7AH1AZ9RSsbvdc6l4odQRrqeQIm2MvRpVq9gJS\nULtP9QdmWc5/EjVluhblxa8FWtO2bQGXXRZJaupk9L48+lixsVmUlMTY1PAb0Dcab8F7TDr1Fn1Q\nS7JKLKPbRhqGarV22YrV3DZgoAu+XUHwTkTwmxG2lTbHj+8kK+shlOwtIyTkFOXlBygr80ctWtJ3\nnApFudlTsIq33mCgN9YLg75QKhrlpV+OfR6t7yk7EiXYek2+vvq1o+U1nWLgH+hZe1aWRnb2U8BS\nlGe/gKCgIMLDs4mIMFJdfYCiIpvaevZzEy25iHIuR80QBAKnLL+/iJrizQUyW7Zk8jdb6Ny1G4LQ\nlBHBb0bY+/WjgPctrxRQXNwG5YPbNv3VO0teQG3bR29yZrtQqhglpzEoJ9w2j24DVGP1823PV4i6\nSNger0u09ThNuxp1UVCWTXi4ucZ6ggJiY+cSHd2bnN8/5s8nV9ADlc3bVuC8iqrSH44ylHq/9TZT\n7ri7MV+rIPgMIvhNGMeVthkZAdgLbQFK0O+zPLbfzi8oKJLQ0AxycsB+Zep2AgK+oby8M0pC26EE\nXpU8qvLKu7GuUf0NdRFohVoEFYSSXF2GQ1H5ti7HJSg7Zz72F4GdqBYIYfj5taekRP8cAOFEtA1j\n4P4JtDhZXNPw7ANq32f8AmwA/mgptxSE5oIIfhPmgQdSSElR1S7p6RrR0S9gK6ABAcFUVtq2Frbv\nHBMenkV09CXk5NyAEu8yoAXV1X+mvPxfqIlafU3qfJTYg8r030bZNDstj5+ynONFVEbvKO6foTJ6\n2wtBEQbDU5bM/iQwGVXeeSctWhykqKhnTbxBPMWf9syhB8qmsZ3yrVXT87fnmPlIoku+Y0HwJUTw\nmzCbNtlPvppMkSi/3AAcoqoqFNiBVWSHoiZXewO7aN26DXv2bANyUPZNN5QL/jFK7HeiRPtt7Gvs\ny7BfHDUHqxVkQElxLPbibsDffzdVVXrXSwNt215ARUUwpaW23n4psbFzadu2M3v2DMOfF3iAp4lE\nTfmWoNbTrkXdY4xCFYgagX1Az7feZqRYOEIzRQTfx6ivxUFdzzvWo1RXF6AmX5cBT6BpytYJCHgG\n6EBl5WGUtbINuICMjN/RtBmo0kt9kdW/UBuRLMde1Gdh3SzcfkMSgyHC0iq72CaeocTEvMjp07EY\nDCauvroNYCQ19YGacw4atJStW49SWmr9DLGxOaSnTyMh4VP279nCgzxNR2qvvT2NMpbyLaMa3nqb\nv4rQC80cEXwfo74WB5s2VWI2twRuID09FFjK1Ve3JDX1JZS1coyIiALy8x0nTcMJDu6C2TwRtcI1\nEHgMNUmql17G2hyvi7njxGtfVG/6LNRCKqtI33ijgV275pKV1cVyvjhiY/ewceM9hIeH1bSgLSgw\n06KF/cpgs7mQMWPmUlDQifDwTFauHMnRjAx6bnyEKyiqKcB0XHt7CDUzkN82lAlfbpIKHEFABN/n\naKjFgV4yefhwW7p00YBpNa/17fsOO3a8YKmpt9opRUU5qOy8LepPwlY+9SZluoAXYW2LYOuOH0Jd\nWKpQ3S5fom3bKAYNakFy8jAAkpI2cvhwT4uYj6/VfK2uHjXh4WGkp0+refzh669wdO7zRKNaIDiu\nHNCAHwEzEP3W20yXrF4QahDB9zEcWxyoCpnaJZNGYxHHjkXYvZafH0OfPhXk5LRCTZqGAK2orn4I\nlQ8/i6qksfXWT6ImVZfTqlUZoaEZ5OYuQS2YeslyjmKUp1+NukNQq2mvu+49OwFvTMOxXdu28Wn8\nYDprGiGWT51nGd22Z/1m4GhQax7/+nvJ6gXBARF8H0H36A8caENs7BxLk7MqysurSE21XgDCwvYw\ncGABTz55Bbfeuhpb8TYai9i06TQwFaugv4+qfAFlyVyCtY/8LtS+sGHAnUREzKWg4EJUnxtFQMBz\nVFY+jaqLWYtqoaA2B8/OjnTJZ1/98VL2JD7MGyhhn24T/VxUfVBHVGfLuLfe5nHJ6gWhTkTwfQTH\nJmeXXvoe0IKjR1sTGzuXdu3i6NbtFMnJdxIeHkZCwkoyMyej574xMb9SXh5JUVE7lH0TjxLyAlQd\n/nKs1TS6NdQH+CcBAZFERBwnK2sq6uKgoQt8dXVny/kqsDY8U6tnjcZKwPle+nXxzovPU/LmK/S0\njOLY2bIT6rK0u20od4lXLwhnRATfwzgrho7e/Y8/+mE2Wy8AV11l3c3KZDKzaVMlqi7+LgCOH99h\n6UNjK+h3Uv8kbAVwjKFDI/jww7sZMiSN48f15z9E7Vg1nerqcMv5PrR7f1jYaZKTbwZqTzQ7s/NW\nWspnbEmYQBuse8sOpfZWJzuB61es5o/SA0cQGkQE38M4K4a1vXt9az8A+92sZszYiNlchbJWQoAi\nqqvD7I4PCqrA3382JSX+qBW2O7H37gNRxY7v2Yy/FvssXu+pY8DP7xjV1db4Bg4MqLlwOV6sGtp5\n6z8LF3D0bzO4Cvu2CMtRl6VZqLqhA35+jFi3kd59Lz/j+QRBUIjgexhnxVDV1VtLFsvL29h590Zj\nUc3dwvr1oNab2u4rOwfb3HjIEPjii3KsneGvB55BbczdApVPG2p8+OTkwWza9CVms2MBJIDGLbdE\n1Cqp1HG8WNW389aWDRv45s7RdEXtcWVfza9G247aB6vlW28zQ7x6QTgr3Cb48+fP5z//+Q/t2rUD\nIDExUfa2rQNnxdCxZLGumvWkJFuf374vjiqvXEZY2GkGDgwgOXkQ69dX2RwTDlyFqrixdrLU4wkP\nD2PgQH9SUmwXQe0gOrqamJh8gHptKceLVV07b+kWzlUood9B7f2xvgcKWrfmwY1SgSMI54JbM/yJ\nEycyceJEdw7h8zgjhnWhXwD0rP6OO7ZZetvrXWRKsG5Q0gb4BX//Mq65xkhy8gjCw8MID8+yW8Wq\nO+V610nHeGrHOr5mgjgl5X7qs6XOtAfsrm3bWDlyCK0qKojCauH0x9pBPwzVmu1SaYsgCI3CrYKv\nVmoKZ+JMYujMhO6jj37BunVKbK37wd4DDMVgeBlNe9Hy2giqql4lNXUKLVooQV65cpRlFWsHNO0A\nXbv2IC5udZ2LovRY580bVBNTUtIGkpMHn7VHD8q++e6uMcRpGq1QbdG+xv5+oyewBwiZ+wp/u39S\ng+cUBOHMuFXwP/74Y1JSUrj44ot54oknCAkJcedwTQZd6Otql+B4cfjxRz9sxTYg4DQXX/yZpea+\nu4PnrpqS6YLctavRbhWrM9Q1yWw0ak7ZUjqrP17KvsSHa/bK0uvpY7G3cHYAwdNncKeIvSC4hEYJ\n/sSJE8nPz6/1fGJiInfffTcPP/wwBoOB119/nblz5zJnzpwGzxkV5RsXBXfFeeKEmZtu+pjMTH17\nQGs1TFZWeK1xDYYT2MpkSEgxv/zyIACjRy+289z1n3FxpWcV/4kTZh56KJWDB4PZv78a2wtMVlY4\n69Zdz5Qpyzl4MJiuXUtYsGAkERG1z3/49995+7rr0PLyuBb7GYYY1KaFy1BtEQ4FBjLhhx+49Mor\nnY7zfNDc/z5diS/ECL4TpzM0SvA/+OADp467/fbbmTx5slPH5uUVNyak84Le7MsdJCSsIjPTdutA\na7uE2NiCWuNefXUbUlP1LpXFXHGFH6NHL7HYQBUMHfoemZlhnDixj4gII927L2X27EHk5RU7vQYg\nIWGVzWSw/d61sbEFVFX5M3/+8Jrjq6pq/ztu2bCBNXeOph2qyfJOVF2QXsV/AHVZy42MYsSaL7nN\nMinrTX8P7vx3dyW+EKcvxAi+FaczuM3SycvLIyoqCoAvv/ySuLg4dw3VpFB2i307ML1dgu0Eqi7W\nmZnRxMbutWm10NbOchk1ailpabcAt9Qay9k1APYe/TDCwl6hS5cLnZpkLjSZ+PD2MZz67Rcisd9A\nUe+8vxXVxrjrW2/zkEzKCoLbcJvgv/zyy+zevRs/Pz86duzI888/766hmhQxMXnArVhbIvzGpk33\n1Mq8HVst6CtthwxJw9kJVGcnW+1LR0MZOLA9ixbd2OBnKTSZePe6fgSeyKcDEIf9fUsUkI5aQnaD\nbDcoCG7HbYKfnJzsrlM3aQyGSlS/GmXRXH55O6daLehi7Wxd/9kce7alo4UmEx/eOoLAHdvpgloC\n1prabYx/B4au38TAmwf4xG2zIPg6stLWy8jO7oCavtQff1bncfWJta04x8WVMnt2/eLsrJCfqXTU\nkS0bNrDqztG0R7VAsF3nexfWNsbfA/1XrJa2CIJwHhHB9zIam3XbinNDE05nI+QNUWgy8cn/3U7Z\nT/8jGrVm19a+CQcWoBZRHbj8Sh5Y/gmh4REuGVsQBOcQwT+POFMV8+STV7J1q76l31FmzhxV57lc\nKdaNpdBk4tUr+hB56iRdgQyUjWNr32QBtGzFdau/kKxeEDyECP55xJmqmLlzfyYr60nAQGmpxpw5\nS1m0yOiJcJ1CX0TVDvsKnGewt286zn1FFlAJgodp8oJfV1ataZzzhhyNwZmqmHNpU+AJ0lI+Iz1h\nAj1QmyN2wN7CuQDV1fJXoK9U4AiCV9DkBb+urBo46w05XIEz/vzZVNl4gkKTifWJD3MkdY1da4SZ\n2Fs4+1FCP12EXhC8hiYv+PVnzOc/i3amKuZcu2eeD45mZLDwun6EVVfVqqmPRO2EG4US+/C/PSdZ\nvSB4GU1e8OvOmM+u2ZercGai1ZsmY3UKTSa+eHgSpWnrCUNtObgT1XxZb41QBJQBmZf2JfG/n0kF\njiB4IU1e8OvPmL0zi/Y2jmZk8J8Bf2RuRTnLgenozZZVa4RoVEYfcNUfeOCj/4jQC4IX0+QFv76M\n2duyaG9k17ZtpA4dRFfq3ua8N/CjwY/79hwQoRcEH6DJC75w9hzNyODzUX/C73guc4FXULZNMbW3\nHBz6xUYRe0HwEUTwBTv0rL43qtfNEdTq2GUooX8JaAvkRbfn9tVfyN6yguBDiOALgEXoRw7hgooK\nLgGGoerrXwKmAGtRk7SFLVpy7efrZbWsIPggIvjNHL0C50Taeru6erXHFrRHZfeHUZ0tbxOhFwSf\nRQTfDTi7k5SnSUv5jG8SJhAFGKlrjy3YB1RFRnHXmi/FvhEEH0cEvwHqEu+GthNzdicpT1JoMvFz\nwgQ6A0+gsnjbCdm9wGZUVi/2jSA0DUTwG6Au8f7sswlnfI8398PRWyPkfLWei4BAVKTxKBvnJJCN\n2nLwZulXLwhNCj9PB+DtnIt4G42FqDwZvKkfztGMDN699CJOpK7huYoKgoBjqEjDgDtRi6hCbhzC\ntL2H+OOAgZ4MVxAEFyMZfgOcSzMzb+yHU2gy8emga7m2vAwT1qx+CfA00BXYHxjI7d9tFa9eEJoo\njRL8devWMX/+fDIyMlixYgV9+vSpee2dd97hk08+wd/fn6eeeor+/fs3OlhPcC7i7S39cMwnTvDf\ne+6hcPO3VBYXM1vTMAAfY83qpwFzAgOpvOkW7ntjviyiEoQmTKMEPy4ujvnz5/P000/bPZ+RkUFq\naipr164lJyeHiRMnsn79egwGQz1n8l68RbzPlqMZGSy89goiNI0eqEVU24FLUTX2r6I6XO5v2Yp7\nf9sjQi8IzYBGCX63burWX9M0u+fT0tKIj48nICCATp06YTQa+e2337jssssaM5zgJLp9E6VpdrtQ\nPY0S/FDADJyIiua2z9eL2AtCM8EtHn5ubi59+/atedy+fXtyc3PdMZTgQKHJxL8HX0e306VUYF9b\n3xVYDByL7ci9GzeL0AtCM6NBwZ84cSL5+fm1nk9MTGTw4MF1vscx4wectnMaqnH3FrwtTvOJE6Q8\n8ACZa9Yws6LCzqvXM/x9QI9Ro3j4/fcJi/Ausfe277M+JE7X4Qsxgu/E6QwNCv4HH3xw1ieNiYkh\nOzu75nFOTg7R0dFOvTcvr/isxzvfREWFeE2cu7Zt48sx8XQ9Xcpx7FfMDgNeADqixF7fW7aiyru+\nZ2/6Ps+ExOk6fCFG8K04ncFldfi2Wf3gwYNZu3Yt5eXlHD16lCNHjnDppZe6aijBhi/HxDP7dCn3\no1bM7sW6AiAU8IvtyIC9h5h+vEi2HBSEZk6jPPyvvvqK2bNnU1BQwOTJk+nZsyfvvvsuPXr0YOjQ\noQwbNoyAgACeeeYZn6zQ8WaOZmSQOm443U6X2vn03VFtEsqBnE6duCPtO/HqBUEAwKDVZbh7EF+5\nffJUnIUmE9/OeIyDa1fzXEUFy1BdLXWffhYQGhZG62uu488fLaGiKtAjcZ4NvnTbLHG6Bl+IEXwr\nTmeQlbY+gi702qYNtDSb6YZ9D5xS4ECrIG5eta6m/01YhG/8sQqCcH4QwfcRvp3xGPemfGqXydv2\nwJkT25G/pO/2ZIiCIHg5IvhejJ7Vhx4+hHbogJ1X3xO1G1V7Pz+yYzowdOUazwUqCIJPIILvxdhm\n9Y419dlhYcQMHMz1ya/JpKwgCE4hgu9l7Nq2jS9G/wljWRm5wALgblRN/SthYXTv0o1CYxfGiNAL\ngnCWiOB7GV+OiefFsrKaTH4ZkIry6SMHDub6RYs9GZ4gCD6MCL6X0a3stJ1XHwKYgoJYPGQo1ye/\n5sHIBEHwdUTwPYztxGyh0cjuwBZo5dYMvxioHjKU4ZLZC4LQSETwPYxduWX6z/x94CCe+vF7jGVl\nHDcYaHX9QMZIZi8IggsQwfcwoYcP2Vk4nQsLuftonidDEgShiSKbmJ9HCk0mPk+4l2+H3MDnCfdQ\nWGCi0Gi02e4cCo1dPBihIAhNGcnwzyOO9s1iDFyf/DqLMVg8/C4yMSsIgtsQwT+PONo3oYcPERoe\nIROygiCcF8TSOY+IfSMIgieRDN8NOJZaXp/8OqHhEWLfCILgUUTw3UBdXv3wRYvFvhEEwaOIpeMG\n6vLqBUEQPI0IvhsQr14QBG9ELB03IF69IAjeSKMEf926dcyfP5+MjAxWrFhBnz59ADh27Bjx8fF0\n69YNgMsuu4xnn3220cH6CuLVC4LgjTRK8OPi4pg/fz5PP/10rdcuuOACVq5c2ZjTC4IgCC6kUYKv\nZ/CapjVwpCAIguBp3DZpm5mZydixYxk/fjw//fSTu4YRBEEQnKTBDH/ixInk5+fXej4xMZHBgwfX\n+Z7o6Gi+/vprQkND2blzJw8//DBr1qyhTZs2DQYUFRXiRNjnD/OJE6Q+9BDBBw9S3LUr8QsWAN4X\nZ31InK5F4nQdvhAj+E6cztCg4H/wwQdnfdLAwEBCQ0MB6NOnD507d+bQoUM1k7pnIi+v+KzHcyef\nJ0yyLqLaupXFZZVM/OwTr4uzLqKiQiROFyJxug5fiBF8K05ncJmlY+vjm0wmqqurATh69ChHjhyh\nc+fOrhrqvCKLqARBaCo0atL2q6++Yvbs2RQUFDB58mR69uzJu+++y08//cTf//53AgIC8PPz4/nn\nn6dt27auivm8Umg0oqX/XLPloCyiEgTBV2mU4N90003cdNNNtZ4fMmQIQ4YMacypvQZZRCUIQlNB\nVto2gCyiEgShqSC9dARBEJoJzVLw69pbVhAEoanTLC2d+vrVC4IgNGWaZYYvpZaCIDRHmqXgS796\nQRCaI03e0qlrf1kptRQEoTnS5AW/Pr9ePHtBEJobTd7SEb9eEARB0eQFX/x6QRAERZO3dMSvFwRB\nUDR5wZfWCIIgCIomb+kIgiAIChF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCY0SvCT\nk5MZOnQoo0aNYtq0aZSUlNS89s477zBkyBCGDh3Kd9991+hABUEQhMbRKMHv378/a9asISUlBaPR\nyDvvvAPA/v37SU1NZe3atSxatIjnnnsOTdMaOJsgCILgThol+Ndeey1+fuoUffv2JScnB4ANGzYQ\nHx9PQEAAnTp1wmg08ttvvzU+WkEQBOGccZmHv2LFCgYOHAhAbm4uHTp0qHmtffv25ObmumooQRAE\n4RxosLXCxIkTyc/Pr/V8YmIigwcPBmDBggUEBgYyfPhwgDrtG4PBUOs5QRAE4fzRoOB/8MEHZ3x9\n5cqVbNq0iSVLltQ8FxMTQ3Z2ds3jnJwcoqOjnQooKirEqeM8jcTpWiRO1+ILcfpCjOA7cTpDoyyd\nb775hnfffZcFCxbQokWLmucHDx7M2rVrKS8v5+jRoxw5coRLL7200cEKgiAI545Ba0T5zJAhQ6io\nqIMzjrUAAATvSURBVCAsLAyAyy67jGeffRZQZZkrVqwgICCAp556iv79+7skYEEQBOHcaJTgC4Ig\nCL6DrLQVBEFoJojgC4IgNBNE8AVBEJoJXiv47733Hj179sRsNns6lDp58803GTlyJKNHj+b+++8n\nLy/P0yHVyZn6HXkT69atY/jw4fTq1YudO3d6Ohw7vvnmG/70pz9xyy23sHDhQk+HUy8zZ87k2muv\nZcSIEZ4OpV5ycnKYMGEC8fHxjBgxwq6c25soLy/ntttuY/To0YwYMYL58+d7OqR6qa6uZsyYMUye\nPLnhgzUvJDs7W7vvvvu0QYMGaQUFBZ4Op05KSkpqfl+yZIn29NNPezCa+tm8ebNWVVWlaZqmvfzy\ny9orr7zi4YjqJiMjQzt48KA2fvx4bceOHZ4Op4aqqirtpptu0jIzM7Xy8nJt5MiR2v79+z0dVp1s\n3bpV27VrlzZ8+HBPh1Ivx48f13bt2qVpmvo/NGTIEK/9Pk+dOqVpmqZVVlZqt912m/brr796OKK6\n+eCDD7Tp06drDz74YIPHemWGP2fOHJKSkjwdxhlp06ZNze+lpaU1PYW8jfr6HXkb3bp1o0uXLl7X\nZO+3337DaDTSsWNHAgMDGTZsGGlpaZ4Oq0769etH27ZtPR3GGYmKiqJXr16A+j/UvXt3jh8/7uGo\n6iYoKAhQ2X5lZaWHo6mbnJwcNm3axG233ebU8Q2utD3fbNiwgQ4dOnDRRRd5OpQGef3110lJSSEk\nJMRrb01tWbFiBcOGDfN0GD5FXX2htm/f7sGImg6ZmZns2bPHaxdlVldXM3bsWI4cOcKf//xnr4xT\nT46Li4udOt4jgl9ff55HH32Ud955h/fff7/mOU9mfA31EUpMTCQxMZGFCxfy0UcfMW3aNA9EeXb9\njjzp7zoTp7fhbXccTYWTJ0/yyCOPMHPmTLu7ZW/Cz8+Pzz77jJKSEh566CH2799Pjx49PB1WDV9/\n/TWRkZH06tWLLVu2OPUejwh+ff159u3bx7Fjxxg1ahSappGbm8u4ceP473//S7t27c5zlA33EdIZ\nPnw4Dz74oMcE/1z6HXkCZ79PbyImJoasrKyax7m5uU73hRLqprKykkceeYRRo0Zx0003eTqcBgkO\nDuYPf/gD3377rVcJ/s8//8yGDRvYtGkTZWVlnDx5kqSkJJKTk+t9j1cZz3FxcWzevJm0tDQ2bNhA\n+/btWblypUfEviEOHz5c83taWhrdunXzYDT1U1+/I2/Gm7LqSy65hCNHjnDs2DHKy8tZs2YNN954\no6fDqhdv+u7qY+bMmfTo0YN77rnH06HUi8lkqrFJTp8+zQ8//OB1/8cfe+wxvv76a9LS0njttdf4\n4x//eEaxBy/08G0xGAxe+wf86quvcvDgQfz8/IiNjeW5557zdEh18sILL1BRUcF9990H2Pc78ia+\n+uorZs+eTUFBAZMnT6Znz568++67ng4Lf39/Zs2axX333Yemadx66610797d02HVyfTp09myZQtm\ns5kbbriBadOmMW7cOE+HZce2bdtYvXo1cXFxjB49GoPBQGJiIgMGDPB0aHbk5eXxxBNPUF1dTXV1\nNfHx8TX7ffgy0ktHEAShmeBVlo4gCILgPkTwBUEQmgki+IIgCM0EEXxBEIRmggi+IAhCM0EEXxAE\noZkggi8IgtBMEMEXBEFoJvw//5K32R/vBHAAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f5be3c99f50\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Current loss: 9.48636\n"
+ ]
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "plt.scatter(inputs, outputs, c='b')\n",
+ "plt.scatter(inputs, model(inputs), c='r')\n",
+ "plt.show()\n",
+ "\n",
+ "print('Current loss: '),\n",
+ "print(loss(model(inputs), outputs).numpy())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "sSDP-yeq_4jE"
+ },
+ "source": [
+ "### Define a training loop\n",
+ "\n",
+ "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "MBIACgdnA55X"
+ },
+ "outputs": [],
+ "source": [
+ "def train(model, inputs, outputs, learning_rate):\n",
+ " with tf.GradientTape() as t:\n",
+ " current_loss = loss(model(inputs), outputs)\n",
+ " dW, db = t.gradient(current_loss, [model.W, model.b])\n",
+ " model.W.assign_sub(learning_rate * dW)\n",
+ " model.b.assign_sub(learning_rate * db)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "RwWPaJryD2aN"
+ },
+ "source": [
+ "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 446
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 569,
+ "status": "ok",
+ "timestamp": 1527005915434,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "XdfkR223D9dW",
+ "outputId": "c43591ae-d5ac-4f2b-a8e7-bfce607e0919"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 0: W=5.00 b=0.00, loss=9.48636\n",
+ "Epoch 1: W=4.58 b=0.42, loss=6.28101\n",
+ "Epoch 2: W=4.24 b=0.76, loss=4.29357\n",
+ "Epoch 3: W=3.98 b=1.02, loss=3.06128\n",
+ "Epoch 4: W=3.78 b=1.23, loss=2.29721\n",
+ "Epoch 5: W=3.61 b=1.39, loss=1.82345\n",
+ "Epoch 6: W=3.49 b=1.52, loss=1.52970\n",
+ "Epoch 7: W=3.38 b=1.62, loss=1.34756\n",
+ "Epoch 8: W=3.30 b=1.70, loss=1.23463\n",
+ "Epoch 9: W=3.24 b=1.76, loss=1.16460\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW0AAAEDCAYAAAD+/1UIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VOXdPvD7zJZ9XwmELQkQIAELsiTsi6xiEBGXAiIW\nbV8WBY2K0tLa4lbsr283qxURtIoioAi8SpFNg6whi0FJKAoJBgLZt5k5c87vj5OZLIRkgEnOGXJ/\nritXJsmZyT0sN1+enPOMIMuyDCIicgs6tQMQEZHzWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu\nxODMQePGjYOvry90Oh0MBgM2b97c1rmIiKgZTpW2IAjYuHEjAgIC2joPERG1wKnlEVmWIUlSW2ch\nIqJWCM5cETl+/HgEBARAEATMmTMH9957b3tkIyKiJpxaHvnggw8QFhaG4uJiLFiwAD179sTgwYPb\nOhsRETXh1PJIWFgYACA4OBgTJ05EVlZWi8fL3t6AIADdugFvvglYrTeflIiIWl8eqampgSRJ8PHx\nQXV1NR5++GEsXrwYI0aMuPadCgtRvfoFeL2zDkJtLWxdu6NqRSrMs+8DDE4N9y4XFuaHoqIKVb73\ntTCTc7SYCdBmLmZyjlYzOaPVSfvy5ct44IEHkJKSgjlz5mDcuHEtFzYAREai6oWXUHwkA9WPPApd\n4QX4L/sVgpMGwWPTvwFRdCocERE15tQPIm9Ew3/FdBcK4P3ntfB89x0IVivEmFhUP/kMzCmzAL2+\nLb79VbT6LysztU6LmQBt5mIm52g1kzPa5YpIKaozKl9+DcWHT6Jm7gLof/wB/r98BEGjh8Fj28cA\nTyckInJKu17GLnWJRuXaP6P40AnUPDgP+jN58F+0AEFjhsO0fRvLm4ioFarsPSJ1647KP/0VxWnH\nUTvnAehPf4+AhfMQNG4ETDs/A/hiOkREzVJ1wyipR09U/OV1lHx9FLX3zIH+uxwEPPQAAieMgunz\nXSxvIqImNLHLny0mDhV/fxMlB4+g9u57YMjORMDcOQicNAamPV+wvImI6miitO1scb1Q8fo6lOz/\nBrUzZsJ4Mh0B99+DwKnjYdy7h+VNRNftL395DR999IHj4+XLl2DVqlWOj//61/+HDz/8txrRboim\nStvO1iceFf96B8V702CeNgPG48cQOGcmAu+cBOOBfSxvInJa//6JyM7OAKBsfldWVorc3FzH17Oz\nM5GQMECteNdNk6VtZ+vXH+Vvv4uSPQdhnjwVxiPfIPCeGQhImQpj2ldqxyMiN5CQMBBZWZkAgLNn\nz6Bnzxj4+PigsrISVqsVP/74A+Liequc0nnqXFN+ncSEASjf8AEMJ0/A+9UX4bH7c5hSpsIycjSq\nnloJcdhwtSMSkRN8Vj8Pj+3bXPqY5jtTULX699f8emhoKPR6Ay5duoisrEz075+I6uoyZGdnwsfH\nBzExsTCotL3GjdD0pN2UOPBnKH/vI5Ts2gPL2PEwHdyPoBmTEDD7LhiOHlY7HhFpVGJiIrKyMpCd\nrZT2gAEDkJWVgaws91oaAdxk0m5KHHQ7yjZtheHIYfi8sgam/Xth2r8X5vETUZ26EuJtg9SOSETN\nqFr9+xan4rbSr18isrIy8d//KssjHh4y/vnPf8HX1wfTpt3V7nluhltN2k2JQ4aibPMnKP1kFyzJ\nI+GxZzeCJo2F/8/vhSHzpNrxiEgjEhIGIC3tIPz9/SEIAgICAlBZWYHs7Cz075+gdrzr4talbWcd\nnoyyrTtQuuUzWIYlweOL/0PQhFHwn3c/9HU/gCCijismJhbl5WXo3z+x0ef8/Pzg7+9er33bLrv8\ntStZhvHAPvi8/AcYjx0BAJin3wWPXz+Hom69lRdn0Ait7jTGTM7RYi5mco5WMznjlpi0GxEEWEeP\nRemO3Sj9YAusPxsEj88+AYYMQeCEUfBc/xaEinK1UxIR3ZBbr7TtBAHWcRNQuutLlG7aCsycCUNO\nNvxSn0BIQm/4Ll8CQ/pxXqhDRG7l1i1tO0GAdex4YMsWFJ88hapnV0EKCYHXu+8gaNJYBI4fCc+3\n/wWhvEztpERErbr1S7sBKSIS1U88heKjmSj9YAvM02bAcOpb+D29HCGJveH7xGIYThzj9E1EmtWh\nSttBp4N13ASUv/2uMn2v/DWk0DB4vbcBQZPHIWjcCHiue5PTNxFpTscs7QakiEhUP/4kio9koHTT\nVpinzYD++1Pwe2aFMn0//j8wHD/K6ZuINKHDl7aDTgfr2PHK9J2eg8rnfgMpNBxe/96IoCnjETQ2\nmdM3kZsqLPwJ8+bNUTuGS7C0myFFRKJm2QoUHzmpTN/T74L+9HfK9J3QC77LfgXDsSOcvonciKCh\nazRuBku7Jfbpe91GXEk/hcrnV0MKj4DX++8iaOoEZfp+6w0IZaVqJyWiVoiiiD/8YTXmz78fy5Yt\ng9lsVjvSDbn1roi8BpddASVJMB7YB6+N62Ha9RkEUYTs5QXzXXejZu5DEAcPcfqqS61elcVMztFi\nLq1nWr3aA9u3u3afujvvFLF6dcsFXFj4E2bPnoF//GMd+vdPwJ/+9CI6dYrGfff93KVZbkbHvSKy\nrel0sI4Zh/K3NuDKye9Q+fxvIYVHwPOD9xA0bSKCxiTB861/cvom0piIiEjH5lAzZsxAZmaGyolu\njFtuzaoVcng4apY+gZrFy2A8uB+eG9fDY+d2+D37FHx/92uYZ8xEzbwF1zV9E93KVq82tzoVt5Wm\na9ru+leSk7Yr6HSwjh6Lin+9o0zfq34HKSISnpv+XTd9D4fnv16HUFqidlKiDquw8Cd8+202AGDH\njh1ITByocqIbw9J2MTk8HDVLHkfxN+ko3fwpau+6G/q8XPitTEVIYm/4LXkMhiOHeeYJUTvr3r0H\ndu36DPPn34+ysjKkpNyjdqQbwh9EtgOhqAiem/4Nz41vw3D2vwAAsU88DAsfxpVREyH16KlKruZo\n/QdZWqLFXMzkHK1mcgYn7XYgh4WhZvEylBw6gdKPt6M25W7oz+QBTz2FkKEDETR6OLxf/gMMWRmc\nwImoRfxBZHvS6WAdORrWkaNReeUKQr/eA/Omj2A6sA8+a1+Gz9qXYYvuCvOUabBMvRPWIcMAN3qV\naCJqe2wElcghIcDChSifcS+EygoYv/wPPHZ+BtPuz+H9xj/g/cY/IAUHwzxpKixTpsMyeizg5aV2\nbCJSGUtbA2RfP1hmzIRlxkzAYoHx64NKgf/fDni9/y683n8Xsrc3LGMnwDx1OiwTJ0EODFI7NhGp\ngKWtNSYTrGPHKy/c8PJaGE4cg8euHTDt3A6PHZ/CY8enkA0GWJNGKgU+ZRqkTlFqpyaidsLS1jKd\nDuLgIRAHD0HV86uhP/09PHZ9BtPO7TAd2AvTgb3AMytg/dkgmKdMh2XqnbDF9VI7NRG1IZ494i4E\nAbbefVD9+JMo/WI/rqTnoOLFV2EZOQaGjJPw/cNvEZw8GEFJg+Dz+9XKHuCSpHZqItVVVlZi69bN\nbfb406dPQGVlJQDgypXLGDnydmRlZTT4+kSUl7vuxcSdLm1JkjBz5kw89thjLvvmdOOkzl1Qu/BR\nlH38Ka7knEH5X/8J89Q7oS/Ih/f/voagKeMRPDAevqlPwLjvS8BiUTsykSoqKsqxdetHzX5NcsFg\n07dvArKzMwEA2dmZ6NWrD7KylI/PnfsRgYFB8Pf3v+nvY+d0aW/YsAExMTEu+8bkOnJQMMz33o/y\n9e/h8qmzKHvnfdTOeQCCuRZe699C4L0pCOkbA79fPgLT9m1A3VRA1BG8/vpfceFCAR5++EH8/e//\ni/T045g3bx5++9vnMX/+fVe9QML777+Lt99+EwBQUJCPFSuW4pFH5mHx4kU4d+7Hqx4/ISHRUdpZ\nWZmYM+dBfPttfYknJCS69Pk4taZdWFiI/fv347HHHsPbb7/t0gDkYt7esEyZBsuUaYAowvhNGky7\nPoPHzs/g+fGH8Pz4Q8geHrCMGQfLlOkw3zEFcmio2qmpAwke1L/Zzxcfz3bJ8U398pdL8MMP/8W6\nde8BANLTjyMrKwsbNnyIyMhIFBb+dM0XSHjllTVITV2Jzp27ICcnG2vXvoQ///kfjY7p3z8R69e/\nBQA4depbPPLIY/joo38DUEo8IWGAUzmd5VRpr1mzBqmpqaio0NZln9QKgwHWEaNgHTEKVb9/GYas\nDOUslF074PH5Lnh8vgu+Oh2sQ4fDMnU6zFOmA2HN/wUhupUkJiYiMjKyxWNqamqQnZ2BVauehn23\nD1EUrzqub99+yM39HrW1tbDZbPD09ERUVGcUFOQjOzsD99/v2j27Wy3tffv2ITQ0FPHx8Th8+LDT\nD+zsdfTtqcNnGj9SeVv7CpCbC3zyCYStW2E6lAbToa/hu+pZICEBYWPGAGPGAKNGARqZwrX4ewdo\nM5fmMzWzxAAAYde68/Ue34TFUg69XufIEBjoDS8vL8fHklQNQajPaDQCOp0JwcHeCAgIwPbtn7by\nHfzQvXs37N//OQYMSEBYmB+GDBmMrKxjKC8vw6Br/E/hRrVa2idOnMCXX36J/fv3w2w2o6qqCqmp\nqXjllVdavJ8WN2NhpgYCI4H5jwLzH4Vw8SI8Pt8Jj53bYUr7CsjKAv7yFwCAGN8X1uHJsCSPhHVY\nMuQwZ/+quI4Wf+8AbeZipqvV1sqoqKh0ZCgtrQZQ31GSZMLly1dw5kwBPD09sXv3HgwbloSaGhkR\nEZ3w4YdbMXbsBABAXl4uYmPjrvoeffr0w7p1b2PhwkdRVFSBbt164YUXViE+vp/Tz93Zf2xbLe3l\ny5dj+fLlAIAjR45g3bp1rRY2uRc5IgK18xagdt4ChPmbUPrFPhi/Pghj2tcwHjsMw6kceK1TfjAj\n9u4D6/BkWJNHwjJ8BOTwcJXTE7XM3z8ACQkDMH/+fRg6NAnDhyc3+rrBYMCCBY9g0aL5iIrqjG7d\nuju+9utfv4A//vElvPPOOthsIsaPv6PZ0k5IGIDNmzehXz/llXF69+6DoqIizJgx0+XP57q2ZrWX\n9uuvv97qsfzXvnVukcligSH9BEyHvlKK/OhhCNXVji+Lcb1gHT4C1uQRsCaNgBTR8jqhSzJphBZz\nMZNztJrJGdxPW0VumclqheHkCRgPfQ3T1wdhOHIYuqr6UwjFmFhYk0Y43lxxib0Wf50AbeZiJudo\nNZMzeBk7XR+jEeLtQyHePhQ1S5crJZ55UllKSTsI4+Fv4LVxPbw2rgcA2Lr3UNbD65ZUpM5d1M1P\n5OZY2nRzjEaIg26HOOh21Cx5HBBFGLIylBI/9BWMh9Lg9d4GeL23AQBg69odluQR9SUe3VXlJ0Dk\nXlja5FoGA8TbBkG8bRBq/mcpYLPB8G0WjF9/VV/iddvNAoAtuiusSSNgsS+ndO3mvi+TTdQOWNrU\ntvR6iIkDISYORM0vFwM2G/Q538KUdtAxjXtu+jc8NylXkNk6d3Gsh1uSRkDq3kPlJ0CkLSxtal96\nPWwJiahJSETNo/8DSBL0p3Ial/hHH8Dzow8AALZOUcCY0fDqkwAxcQDEhETI/gEqPwki9bC0SV06\nHWz9+qOmX3/U/OKXSol//x2MaV/BlKYsqeD99+GL9x13EXv0VKb3hAF1RT5Aefk2og6ApU3aotPB\nFt8Xtvi+qF24CJBlhJUWonx/GgyZGcpb1kl4frIF+GSL4262LtH1JZ44AGLiwDY5Z5zcT2VlJXbv\n/j/MnHlPm32PNWt+i+TkkRg9elybfQ87ljZpmyAAvXrBHNQJ5pRZyudkGbr8844CN2RmwJhxEh67\nPoPHrs8cd7WFR9SXeMJAiIkDIHWJ5g86Oxj7ftpNS1uSJOh07vc6MCxtcj+CACm6KyzRXWGZdqfj\n07qLhTBknmwwkWfA4z9fwOM/XziOkYKCHAVuf7N17wm44V9edzVokE+znz9+vMolxzfVcD9tvV4P\nLy9vREVF4ttvc/Dqq39Gaurj2LBhEwBlL+3a2hosWPALFBTk47XXXkFZWSk8PT2Rmvocunbtds3v\nc/ToYXz44fsoKSnG4sVPIClphFP5rhdLm24ZUkQkLBMnwzJxsuNzwpUrMGTVl7gh82T962va7+fr\nBzEh0bE+LiYOhC02DjDwr8etoOF+2unpx5Ga+gTWrn0VRqPfTe+l3VBh4U/429/eRH7+eSxd+hg2\nbdoGo9Ho8ufDP5V0S5NDQmAdMw7WMfVrjUJ5GQzZWfVTeVYGjIcPwXTo6/r7eXlB7NvfsT4uJg6A\n2DseMJnUeBq3FGcn5Bs9vjV9+/ZDVFRUi5exO7uXdkPjxk0EAHTpEo2oqM748ccfmt1c6maxtKnD\nkf0DHOeCO1RVwZCT3WAiz4AhIx3G40fr72c0QozvpxR4/0Rg2CAIIZ2VnQ65Tu42PD09Hbf1ej1s\ntvrXibRYzAAAWZbg5+fveLUbZzSd2K81wd8sljYRAPj4OPZUcTCbYfgup9FZK4Zvs2HMPOk4JBSA\n5B8AW2wsbDFxsMX1glj33tajJ+Dh0f7PhRrx9vZGdd3OlE33xwsKCkZpaQnKy8vh6emJtLSvMGxY\nEry9fdCpUxT27v1Pq3tp2+3d+x9MnjwNFy4U4MKFghbXv28GS5voWjw8IA64DeKA2+o/Z7VCn3sa\nhqwM+F/4EeaMbOjP5MKQlQnjieON7i7rdJC6doMYG+codFtsHMTYXsqLSXA6bxcN99M2mTwQHBzs\n+Jor9tK2i47uhsWLF6GkpBhPPbWyTdazAW7Nqipmco4WMwFNcokidOd+hOFMLvS5udCfyVXKPS8X\nustFV93XMZ3H1he5LTbupqdzLf5aMZNzuDUrUXsyGCD1jIGlZwzQ4OwVABBKS6DPy4U+LxeGuvf6\nvNOtT+f2Iq9bcuF0TgBLm6jNyYFBEAcPgTh4CMwNvyCK0J/7oa7E86DPO11X7KeVc8sbnF8O1E3n\nccpSixjXS1lyccF0Ts7bsGEd9u79DwRBgCzLEAQBY8dOwNy5C9otA5dHVMRMztFiJqBtc101neee\nVpZczv4XgtXa6FjHdB7XCx59eqEyJBK26GhIXaJh69IVcmioqhO6Fn//tJrJGZy0iTTIqem8bu3c\nUFfoHrs/B3Z/Dt+mj+XlBVvnLkqJR3eFFN0VtrpCl6KjIUV2AvT6dnx2dDNY2kTuxGCArWcsbD1j\ngTumNPqSUFKM0MorKMv8Dvr8c9Dln4f+/Pm69z/CkJfb7EPKBgOkqM6wdbFP59GOYpeio2HrHM3l\nFw1haRPdIuSgYKBXN1iir3FaWmUl9PnnlUI/fx76/PPQ5Z9zFLvx0NcQrrFaaguPUAo8uiukLg0K\nvW5al32d+6893TyWNlFH4esLW5942PrEN/91sxm6CwV1ZX4e+vPn6m+fOwdDxkkYjx9r9q5SYKBS\n4F2i69bT64sdiX0A2YNLMC7C0iYihYcHpB49IfXo2fzXbTboLhbWTen1yy/224b/5kHIzmz2rqF6\nPaTQMEjhEZAiIhq/D49s9DG8vdvwSbo/ljYROUevhxTVGVJUZ4hDh139dVmGUFzcYPlFKXPv4iKI\n5wuUrXPP5ELIymjx20h+/pDCwyFFRCrvHcVu/1wEpIhIyMHBHXJLXZY2EbmGIEAOCYEYEgI0uPTf\nO8wPpQ1OrxMqK6C7dBG6ixfr3hdCd+lS3fv6z+v/e+aaa+xA3Q9Qw8KbTO0RjlJvWPJosEmUu2Np\nE1G7kn39YPP1U86AaYnVCt2Vy1eVeePCvwjD96cgZKS3+FBSQGD91B4RAXTtAm9PX0jBIZCCgyEH\nBUMKCoYcEgIpKFjTJc/SJiJtMhohRXZSziNviSxDqChvMq03md7r3gy5px13a/71cOoe0ttbKfSg\nukIPaVDswcH1X6u7LQcHQ/bxbZeLmFjaROTeBAGyfwBs/gHKKw61xGKB7nIRQsQqlJ45D11JMYSS\nYuiuXGl0Wygpga6kGIYzeRCqnXsRBtlobDSty0H1hS4FBSsTfXDj4pcDAq97XZ6lTUQdh8kEKaoz\nEOYHa9dezt3HbFYKvbgYuuIrSrHbbxcX15e9/eOfLsBwKseph5Z1OsiBgZCCQ4AG/wtoCUubiKgl\nHh7KEk1kJ9icvY8oQigtVQq9boq/ZvHX3XYWS5uIyNUMBsihobCFhgJOvkxkmJMP3fFOciQicmMs\nbSIiN8LSJiJyIyxtIiI30uoPIi0WCx588EFYrVbYbDZMmjQJixcvbo9sRETURKulbTKZsGHDBnh5\necFms+H+++/HqFGjkJiY2B75iIioAaeWR7y8vAAoU7coim0aiIiIrs2p0pYkCSkpKUhOTkZycjKn\nbCIilTh1cY1Op8O2bdtQWVmJX/3qV8jLy0NsbAs7dHXvjmDp6i0Vi49nN3t48KD+zX7epcfrhKsy\nqZoHuCqT6nmaZNJEngaZNJPH7tyPmsrD42+N41tzXVdE+vr6YsiQITh48GDLpQ1Ar7t6t6trvkR8\nM8e2xfFNM6mdp2kmLeRpmEkreeyZtJSnxfuolMd+/FX3UznPVffVQJ5GH2skj7MEWW5hl3EAxcXF\nMBqN8PPzQ21tLRYuXIhFixZh9OjRLT5wUYNNz7UgLMyPmZzATM7TYi5mco5WMzmj1Um7qKgIzzzz\nDCRJgiRJmDp1aquFTUREbaPV0u7duze2bt3aHlmIiKgVvCKSiMiNsLSJiNwIS5uIyI2wtImI3AhL\nm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uI\nyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiN\nsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0\niYjciKG1AwoLC5GamorLly9Dr9dj9uzZmDdvXntkIyKiJlotbb1ej2effRbx8fGoqqrC3XffjeTk\nZMTExLRHPiIiaqDV5ZGwsDDEx8cDAHx8fBATE4NLly61eTAiIrrada1p5+fn47vvvkNiYmJb5SEi\noha0ujxiV1VVhaVLl2LlypXw8fFp8dju3QFJuvqY48ermj1+0KDmH8+Vx+t0V2dSMw+AqzKpnadp\nJi3kaZhJK3nszp1r9tOq5eHxt8bxrXGqtEVRxNKlS3HXXXdhwoQJTj2wTnf1EB8W5neNY5t/DFcf\n3zST2nmaZtJCnoaZtJLHnklLeVq6j1p57Mc3vZ/aeZre1kKehh9rJY+zBFmW5dYOSk1NRVBQEJ59\n9lmnH7ioqOKGArWVsDA/ZnICMzlPi7mYyTlazeSMVte0jx8/ju3bt+Obb75BSkoKZs6ciQMHDtx0\nQCIiun6tLo8MGjQIp06dao8sRETUCl4RSUTkRljaRERuhKVNRORGWNpERG6EpU1E5EacviKSiIiu\nnyQBZWVASYmA4uL6t5ISwfG5khIBn37q3OOxtImInGSxoFHRNizgpkWsfAyUlgqQJMFlGVjaRNTh\nyDJQWYmrCvdaU7D981VVzpWvXi8jKEhGaKiMuDgJQUEygoOVt6Ag1L2XHe+DgmQAvk49NkubiG4Z\nNTVAUZGAixcFXLqkw6VLyu2iIuVj5fMCLl8GLBbnLhv39lZKtUcPqVHR1pdw4/INCZHh5wcIrhuu\nG2FpE5GmSZIyEV+6JDhK2F7IDd8uXtShvLzlpvTwkBERIWPgQMDfX2yxfO23vbza6Yk6iaVNRKqo\nqUGjwm1cwvVTcVGRAFFsuYxDQiR07izhtttkhIfLiIiQEB5uvy3X3Zbg769MwMqGUTXt9Exdi6VN\nRC4likBhoYD8fB3OnxdQUQGcPevRYEpWStn5qVhCeLjUqIAblnJYmAyjsZ2enAawtInoutTUABcu\nCDh/Xof8fB3y8+23laK+cEGAzda0kE2OW9eaiusnYuVzbbku7M5Y2kTUSFkZGpVw49sCLl9u/po8\nQZARGSnjZz+T0KWL/U1GfLwnPD2rEBGhnE3RkabitsDSJupAZFlZR25Ywsq0XH+7oqL58dZolNG5\ns4z4eBFdusjo0kVCdLTkuB0VJcNkuvp+YWGeKCqS2viZdRwsbaJbiNUKnDvXtJDrlzIKCgSYzc2X\nso+P3KiEu3SxfywhOlpZtmjppdeofbC0idyMLAM//SQgL0+H3FwdzpzRIS9PeV9QAEhS8xdphIZK\niI9X1pPrC7m+mAMDuYbsDljaRBpVXQ2cOaOUccNyzsvTobr66naNiJCQlARERFgbTczR0TI6d5bg\n7a3CkyCXY2kTqcg+Nefm1heyfWrOz796LcLTU0bPnhJiYxu/xcQoZ1so5x/XqvBMqL2wtInagX1q\nbljK9um5uak5MlLCyJEiYmIal3OXLlxX7uhY2kQuIsvK+csNJ2b7W0HBtafmuDjJUc72277O7R1E\nHRBLm+g6WSzA99/rcOkScOKEqdH03NzU3KmTMjU3XMqIi5PQuTOnZrp+LG2iFtTUADk5OmRm6pGV\npbw/dUoHq9Vezh4AAC+va681c2omV2JpE9WprASys/XIzKwv6dOndY0uyfbwkJGQIKF/fxsGDzYh\nIqIasbGcmqn9sLSpQyopAbKylIJW3utx5kzj1vX2ljF4sA2JiRISEpT3cXGS4zLssDATiopsKqSn\njoylTbe8S5cEx9KGvaTPnWtc0AEBMkaOFJGQICEx0YbERBt69uT0TNrD0qZbhv3sjYblnJmpQ2Fh\n4+YNDZUwbpyIxESbo6S7dpV5NSC5BZY2uSVZBn74QXAUs30N+sqVxgUdFSVh8mRrgwlaQmQkC5rc\nF0ubNM9mA06f1jUq56ws/VWb6HfrJiEpyepYg05IkBAWJquUmqhtsLRJc8xmID1dj6+/1iMtTY/j\nx4Hqah/H1wVBeYXrCRPqp+f+/W0IDFQxNFE7YWmT6mprgRMnlIJOS9Pj2DE9amvrp+h+/YCEBKtj\nDbpfPxvPfaYOi6VN7a6mBjh+vL6kjx/XO/Z4FgQZfftKSE62YfhwG4YPF9G7NzdBIrJjaVObq64G\njh2rL+lQSYIFAAANpklEQVQTJ/SwWOpLun9/CUlJNiQl2TBsmIigIJUDE2kYS5tcrqrq6pK2X/at\n09WXdHKyiKFDuRZNdD1Y2nTTKiuBo0ftJW1AeroOolhf0omJ9klaKemAAJUDE7kxljZdt8pK4MgR\n+9kdBmRk1Je0Xi9jwAAJw4crk/SQITb4+6scmOgW0mppr1y5Evv27UNISAi2b9/eHplIYyoqgMOH\n6yfpjIz6TZT0ehkDB0pIShKRnGzDkCE8s4OoLbVa2nfffTfmzp2L1NTU9shDGlBeDnzzjVLQaWnK\nFYeSpJS0wSDjttskJCeLGD6cJU3U3lot7cGDB6OgoKA9spBKZBk4eVKHXbsMOHgQSE/3dZS00ajs\ndGc/u+P2223w8WnlAYmozXBNu4OyWoFDh/TYtcuAXbsMuHBB2bPDaARuv92G5GSlpAcPtvFVvIk0\npM1KOyzMr60e+oZ19EzV1cAXXwBbtwLbtyt7SgNAYCAwdy6QkgJMmgT4+BigtX/Ptfh7B2gzFzM5\nR4uZnNFmfzOLiira6qFvSFiYX4fMVFICfPGFATt3GrBvnwE1NcqyR2SkhAULREydKiIpyebY2N/H\np2P+Ot0ILeZiJudoNZMznCptWeZOae7kwgUBu3YpRZ2Wpnec6REba8PUqUpRDxwocYN/IjfUammv\nWLEChw8fRmlpKcaMGYMlS5Zg1qxZ7ZGNrkNurg47dypFnZ6ud3z+ttuUop4yRUSvXpKKCYnIFVot\n7bVr17ZHDrpOkqSc8WEv6rw8paj1euVls+xFHRXF/yUR3Uq09dMmapHVCqSl1Z/x8dNPyvqGl5eM\nKVOsmDpVxB13cMMlolsZS1vjqquBvXuVaXr3bgNKS5X16cBAGffeqxT1mDEiT8sj6iBY2hpUUgJ8\n/rlS1Pv315/xERUlYdYspaiHDas/44OIOg6WtkYUFAiOZY+GZ3z06lX/g8SBAyW+IC1RB8fSVtGp\nU8C775qwc6cBJ0/Wn/Hxs5/ZT82zIjaWP0gkonos7XZWWgp89JER775rxKlTAOABg0HGqFH1Z3x0\n6sSiJqLmsbTbgSwDR4/qsGGDCZ9+akBtrQCjUcbMmcCECTWYOFHkq7cQkVNY2m2orEyZqjduNOLU\nKWX5o0cPCXPnmjFnjoi+fX1RVCSqnJKI3AlL28VkGTh2TIeNG0345BPlzA+jUcZdd1kxd64VI0bY\nePk4Ed0wlraLlJUBmzcbsWFD/VTdrZuEuXMtuP9+K8LCuE5NRDePpX0TZBk4cUJZq962TZmqDQYZ\nM2YoU/XIkZyqici1WNo3oLy8fqrOyWk8Vd93nxXh4ZyqiahtsLSdJMtAeroOGzYYsW2bEdXVylQ9\nfboV8+ZZMWoUp2oianss7VZUVChT9caNRmRnK1N11671U3VEBKdqImo/LO1m2F/oduNGI7ZsUaZq\nvV7GtGnKVD16NKdqIlIHS7uBysr6qTorq36q/vnPlTNAOFUTkdpY2gAyMpS16o8/rp+qp05Vpuox\nYzhVE5F2dNjSrqwEtmxRzgDJzFSm6i5dJCxdasEDD1gRGcmpmoi0p8OVdmamDu+8o6xVV1UpU/Xk\nyVbMn69M1Xp9649BRKSWDlHalZXAtm3A3//u7dgCtXNnCYsXK1M1d9UjIndxS5d2eTnwxhsmvP66\nCeXlgE6nw+TJylr12LGcqonI/dySpV1ZCbz1lgl/+5sJpaUCQkIkrF4tICWliq9OTkRu7ZYq7epq\nYN06I/72NxOuXNEhMFDGc8+ZsXChBT16+KGoiIVNRO7tlijt2lpgwwYj/vxnE4qKdPD3l5Gaasai\nRRb4+6udjojIddy6tM1m4N13lbIuLNTBx0fG8uVmPPaYha8EQ0S3JLcsbasVeP99I/70JxMKCnTw\n9paxZIkZv/qVFSEhXAIholuXW5W2KAIffWTA2rUeOHdOB09PGY89ZsGSJRa+yAARdQhuUdo2G7Bl\niwF//KMHzp7VwWSS8cgjFixbZuF+IETUoWi6tCUJ+PRTA1591YTcXD2MRhkPPWTB449beOoeEXVI\nmixtSQJ27lTK+tQpPfR6GT//uVLWXbuyrImo49JUacsy8MUXerz8sgeys/XQ6WTMmWPF8uVm9OjB\nsiYi0kRpyzKwd69S1unpegiCjLvvtuLJJ82IjWVZExHZqVrasgwcPKiU9dGjykYgM2ZY8eSTFvTp\nI6kZjYhIk1Qr7UOH9HjpJRMOHVIiTJlixVNPWdC/P8uaiOha2r20jx7V4aWXPHDwoPKtJ04UkZpq\nxoABLGsiotY49UJaBw4cwOTJkzFp0iS88cYbN/SNTpzQ4b77vDBtmg8OHjRgzBgRu3ZV4b33aljY\nREROanXSliQJL7zwAtavX4/w8HDcc889GD9+PGJiYpz6BllZOrzyigc+/1z5ViNGiEhNtWDYMNvN\nJSci6oBaLe3MzEx069YNnTt3BgBMmzYNe/bsabW0c3J0ePVVE3bsMAIAhg4V8fTTFowYwbImIrpR\nrZb2xYsX0alTJ8fHERERyMrKavE+990HfPihN2RZwKBBNjz9tBmjR9sgCDcfmIioI2u1tGX5+s+T\n3rQJGDBAwtNPmzF+PMuaiMhVWi3tyMhIXLhwwfHxxYsXER4e3uJ9lJ7XA/C+yXiuFRbmp3aEqzCT\nc7SYCdBmLmZyjhYzOaPVs0cSEhJw7tw5FBQUwGKxYMeOHRg/fnx7ZCMioiZanbT1ej1WrVqFhx9+\nGLIs45577nH6zBEiInItQb6RRWsiIlKFUxfXEBGRNrC0iYjcCEubiMiNuHTDqAMHDmDNmjWQZRmz\nZs3CokWLXPnwN2TlypXYt28fQkJCsH37drXjAAAKCwuRmpqKy5cvQ6/XY/bs2Zg3b56qmSwWCx58\n8EFYrVbYbDZMmjQJixcvVjWTnSRJmDVrFiIiIvD666+rHQfjxo2Dr68vdDodDAYDNm/erHYkVFRU\n4LnnnkNubi50Oh3WrFmDAQMGqJrp7NmzeOKJJyAIAmRZxvnz57Fs2TLV/6yvX78emzdvhiAI6NWr\nF1588UWYTCZVM73zzjuOP0et9oHsIjabTZ4wYYKcn58vWywWecaMGXJeXp6rHv6GHT16VM7JyZGn\nT5+udhSHS5cuyTk5ObIsy3JlZaV8xx13aOLXqrq6WpZlWRZFUZ49e7ackZGhciLF22+/La9YsUJ+\n9NFH1Y4iy7Isjxs3Ti4tLVU7RiNPP/20vHnzZlmWZdlqtcoVFRUqJ2rMZrPJycnJ8oULF1TNUVhY\nKI8bN042m82yLMvysmXL5K1bt6qa6fTp0/L06dNls9ksi6IoP/TQQ/KPP/54zeNdtjzScI8So9Ho\n2KNEbYMHD4a/v7/aMRoJCwtDfHw8AMDHxwcxMTG4dOmSyqkALy8vAMrULYqiymkUhYWF2L9/P2bP\nnq12FAdZliFJ2tmZsrKyEseOHcOsWbMAAAaDAb6+viqnaiwtLQ1du3ZttCWGWiRJQk1NDURRRG1t\nbasXC7a1M2fOYODAgTCZTNDr9bj99tuxe/fuax7vstJubo8SLRSR1uXn5+O7775DYmKi2lEgSRJS\nUlKQnJyM5ORkTWRas2YNUlNTIWhoLwRBELBw4ULMmjULH374odpxkJ+fj6CgIDz77LOYOXMmVq1a\nhdraWrVjNbJz505MmzZN7RiIiIjAggULMGbMGIwaNQp+fn5ISkpSNVNcXByOHj2KsrIy1NTU4MCB\nA/jpp5+uebzLSlvm6d7XraqqCkuXLsXKlSvh4+OjdhzodDps27YNBw4cQEZGBvLy8lTNs2/fPoSG\nhiI+Pl5Tf74++OADbNmyBW+++Sbee+89HDt2TNU8oigiJycHDzzwALZu3QpPT88b3ve+LVitVnz5\n5ZeYMmWK2lFQXl6OPXv2YO/evTh48CCqq6tV/1lXTEwMfvGLX2DBggVYtGgR+vTpA4Ph2j9udFlp\n38geJR2ZKIpYunQp7rrrLkyYMEHtOI34+vpiyJAhOHjwoKo5Tpw4gS+//BLjx4/HihUrcPjwYaSm\npqqaCVCWtwAgODgYEydObHXXy7YWGRmJyMhIJCQkAAAmTZqEnJwcVTM1dODAAfTr1w/BwcFqR0Fa\nWhqio6MRGBgIvV6PiRMnIj09Xe1YmDVrFrZs2YKNGzciICAA3bp1u+axLittLe9RoqUpzW7lypWI\njY3F/Pnz1Y4CACguLkZFRQUAoLa2FocOHULPnj1VzbR8+XLs27cPe/bswWuvvYahQ4filVdeUTVT\nTU0NqqqqAADV1dX46quvEBcXp2qm0NBQdOrUCWfPngUAfPPNN5raamLHjh2YPn262jEAAFFRUcjI\nyIDZbIYsy5r5tSouLgYAXLhwAbt3727x18tlp/xpdY8S+4RWWlqKMWPGYMmSJY4f2Kjl+PHj2L59\nO3r16oWUlBQIgoAnnngCo0aNUi1TUVERnnnmGUiSBEmSMHXqVIwePVq1PFp1+fJlLF68GIIgwGaz\n4c4778SIESPUjoXnn38eTz75JERRRHR0NF588UW1IwFQBoC0tDT87ne/UzsKACAxMRGTJk1CSkoK\nDAYD+vbti3vvvVftWFiyZAnKyspgMBjwm9/8Bn5+196BkHuPEBG5EV4RSUTkRljaRERuhKVNRORG\nWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu5P8D+7Wym3BFpegAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f5be4b8ec50\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "model = Model()\n",
+ "\n",
+ "# Collect the history of W-values and b-values to plot later\n",
+ "Ws, bs = [], []\n",
+ "epochs = range(10)\n",
+ "for epoch in epochs:\n",
+ " Ws.append(model.W.numpy())\n",
+ " bs.append(model.b.numpy())\n",
+ " current_loss = loss(model(inputs), outputs)\n",
+ "\n",
+ " train(model, inputs, outputs, learning_rate=0.1)\n",
+ " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n",
+ " (epoch, Ws[-1], bs[-1], current_loss))\n",
+ "\n",
+ "# Let's plot it all\n",
+ "plt.plot(epochs, Ws, 'r',\n",
+ " epochs, bs, 'b')\n",
+ "plt.plot([TRUE_W] * len(epochs), 'r--',\n",
+ " [TRUE_b] * len(epochs), 'b--')\n",
+ "plt.legend(['W', 'b', 'true W', 'true_b'])\n",
+ "plt.show()\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "vPnIVuaSJwWz"
+ },
+ "source": [
+ "## Next Steps\n",
+ "\n",
+ "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n",
+ "\n",
+ "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n",
+ "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n",
+ "\n",
+ "The [next tutorial](TODO) will cover these higher level APIs."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "default_view": {},
+ "name": "Training Models",
+ "provenance": [],
+ "version": "0.3.2",
+ "views": {}
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index 2d51cfdeee..b14ef1df8f 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -49,15 +49,17 @@ def random_batch(batch_size, data_format):
return images, one_hot
-def train_one_step(model, images, labels, optimizer):
-
+def compute_gradients(model, images, labels):
with tf.GradientTape() as tape:
logits = model(images, training=True)
loss = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=labels)
tf.contrib.summary.scalar(name='loss', tensor=loss)
- grads = tape.gradient(loss, model.variables)
- optimizer.apply_gradients(zip(grads, model.variables))
+ return tape.gradient(loss, model.variables)
+
+
+def apply_gradients(model, optimizer, gradients):
+ optimizer.apply_gradients(zip(gradients, model.variables))
class ResNet50Test(tf.test.TestCase):
@@ -114,7 +116,8 @@ class ResNet50Test(tf.test.TestCase):
with tf.device(device), tfe.execution_mode(execution_mode):
optimizer = tf.train.GradientDescentOptimizer(0.1)
images, labels = random_batch(2, data_format)
- train_one_step(model, images, labels, optimizer)
+ apply_gradients(model, optimizer,
+ compute_gradients(model, images, labels))
self.assertEqual(320, len(model.variables))
tfe.async_wait()
events = summary_test_util.events_from_logdir(logdir)
@@ -138,14 +141,16 @@ class ResNet50Test(tf.test.TestCase):
# garbage to be collected. The hope is that this is a build-only effect,
# and a subsequent training loop will create nothing which needs to be
# collected.
- train_one_step(model, images, labels, optimizer)
+ apply_gradients(model, optimizer,
+ compute_gradients(model, images, labels))
gc.collect()
previous_gc_debug_flags = gc.get_debug()
gc.set_debug(gc.DEBUG_SAVEALL)
for _ in range(2):
# Run twice to ensure that garbage that is created on the first
# iteration is no longer accessible.
- train_one_step(model, images, labels, optimizer)
+ apply_gradients(model, optimizer,
+ compute_gradients(model, images, labels))
gc.collect()
# There should be no garbage requiring collection.
self.assertEqual(0, len(gc.garbage))
@@ -180,9 +185,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return (16, 32, 64)
if tf.DeviceSpec.from_string(device.name).device_type == 'TPU':
- # TODO(iga): Training fails with batch size of 16, probably because of
- # no layout optimizations with op-by-op mode. Investigate more.
- return (8,)
+ return (32,)
return (16, 32)
def _report(self, label, start, num_iters, device, batch_size, data_format):
@@ -248,18 +251,21 @@ class ResNet50Benchmarks(tf.test.Benchmark):
device, data_format = device_and_format
for batch_size in self._train_batch_sizes():
(images, labels) = random_batch(batch_size, data_format)
- num_burn = 3
- num_iters = 10
model = resnet50.ResNet50(data_format)
+ optimizer = tf.train.GradientDescentOptimizer(0.1)
+ apply_grads = apply_gradients
if defun:
model.call = tfe.defun(model.call, compiled=compiled)
- optimizer = tf.train.GradientDescentOptimizer(0.1)
+ apply_grads = tfe.defun(apply_gradients, compiled=compiled)
+ num_burn = 3
+ num_iters = 10
with tf.device(device):
iterator = make_iterator((images, labels))
for _ in xrange(num_burn):
(images, labels) = iterator.next()
- train_one_step(model, images, labels, optimizer)
+ apply_grads(model, optimizer,
+ compute_gradients(model, images, labels))
if execution_mode:
tfe.async_wait()
self._force_device_sync()
@@ -268,7 +274,8 @@ class ResNet50Benchmarks(tf.test.Benchmark):
start = time.time()
for _ in xrange(num_iters):
(images, labels) = iterator.next()
- train_one_step(model, images, labels, optimizer)
+ apply_grads(model, optimizer,
+ compute_gradients(model, images, labels))
if execution_mode:
tfe.async_wait()
self._force_device_sync()
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
index 492adbe1d8..5ee2176154 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
@@ -152,7 +152,7 @@ class RNNColorbot(tf.keras.Model):
self.label_dimension = label_dimension
self.keep_prob = keep_prob
- self.cells = self._add_cells(
+ self.cells = tf.contrib.checkpoint.List(
[tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes])
self.relu = layers.Dense(
label_dimension, activation=tf.nn.relu, name="relu")
@@ -204,14 +204,6 @@ class RNNColorbot(tf.keras.Model):
hidden_states = tf.gather_nd(chars, indices)
return self.relu(hidden_states)
- def _add_cells(self, cells):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of layers.Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for i, c in enumerate(cells):
- setattr(self, "cell-%d" % i, c)
- return cells
-
def loss(labels, predictions):
"""Computes mean squared loss."""
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index 74701b2f4f..c2340a293a 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -50,7 +50,7 @@ class RNN(tf.keras.Model):
def __init__(self, hidden_dim, num_layers, keep_ratio):
super(RNN, self).__init__()
self.keep_ratio = keep_ratio
- self.cells = self._add_cells([
+ self.cells = tf.contrib.checkpoint.List([
tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)
for _ in range(num_layers)
])
@@ -74,14 +74,6 @@ class RNN(tf.keras.Model):
# tuple (output, output_states).
return [input_seq]
- def _add_cells(self, cells):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for i, c in enumerate(cells):
- setattr(self, "cell-%d" % i, c)
- return cells
-
class Embedding(layers.Layer):
"""An Embedding layer."""
diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py
index 4032e755f6..90a3711475 100644
--- a/tensorflow/contrib/eager/python/saver_test.py
+++ b/tensorflow/contrib/eager/python/saver_test.py
@@ -60,15 +60,9 @@ class SaverTest(test.TestCase):
def testSameNameNoClobbering(self):
with ops.device(self._dev()):
- # Note that this test purposefully uses Graphs rather than
- # IsolateTest. Users are more likely to accidentally create the same
- # variable name this way.
- first_graph = ops.Graph()
- with first_graph.as_default():
- v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
- with ops.Graph().as_default():
- v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
- saver = _saver.Saver([v1_first_graph, v1_second_graph])
+ v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
+ v2 = resource_variable_ops.ResourceVariable(2.0, name='v1')
+ saver = _saver.Saver([v1, v2])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'v1'):
saver.save(ckpt_prefix)
@@ -126,12 +120,11 @@ class SaverTest(test.TestCase):
saver = _saver.Saver([v1])
saver.save(ckpt_prefix)
- with ops.Graph().as_default():
- saver = _saver.Saver([v1])
- with _saver.restore_variables_on_create(ckpt_prefix):
- # Value is from checkpoint, but not from argument.
- ret, _ = model(2.0)
- self.assertEqual(ret.numpy(), 1.0)
+ saver = _saver.Saver([v1])
+ with _saver.restore_variables_on_create(ckpt_prefix):
+ # Value is from checkpoint, but not from argument.
+ ret, _ = model(2.0)
+ self.assertEqual(ret.numpy(), 1.0)
def testRestoreNotFound(self):
with ops.device(self._dev()):
@@ -184,17 +177,17 @@ class SaverTest(test.TestCase):
4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
# reset the graph and reload on create, so that 1 + 2 = 3
- with ops.Graph().as_default():
- with _saver.restore_variables_on_create(ckpt_prefix):
- @graph_callable.graph_callable(
- [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
- def model2(x):
- v = variable_scope.get_variable(
- 'v', initializer=init_ops.zeros_initializer(), shape=())
- return v + x
-
- self.assertEqual(
- 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
+ ops.reset_default_graph()
+ with _saver.restore_variables_on_create(ckpt_prefix):
+ @graph_callable.graph_callable(
+ [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
+ def model2(x):
+ v = variable_scope.get_variable(
+ 'v', initializer=init_ops.zeros_initializer(), shape=())
+ return v + x
+
+ self.assertEqual(
+ 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
class GetOptimizerTests(test.TestCase):
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index 4808b9ee30..ddd6aa442f 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -72,7 +72,7 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
estimator: A `tf.estimator.Estimator` instance to call evaluate.
input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A
function that constructs the input data for evaluation.
- See @{$get_started/premade_estimators#create_input_functions} for more
+ See @{$premade_estimators#create_input_functions} for more
information. The function should construct and return one of
the following:
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 2f3e57653c..b7194ae333 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -2022,6 +2022,7 @@ class GDN(base.Layer):
def beta_initializer(shape, dtype=None, partition_info=None):
del partition_info # unused
+ pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype)
return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal)
def gamma_initializer(shape, dtype=None, partition_info=None):
@@ -2029,6 +2030,7 @@ class GDN(base.Layer):
assert len(shape) == 2
assert shape[0] == shape[1]
eye = linalg_ops.eye(shape[0], dtype=dtype)
+ pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype)
return math_ops.sqrt(self._gamma_init * eye + pedestal)
beta = self.add_variable(
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 8ed9f446bc..0e35b1aa8b 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -46,6 +46,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
__all__ = ["rev_block", "RevBlock", "recompute_grad"]
@@ -449,6 +450,15 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
`variable_scope(name, use_resource=True), which are the default in Eager mode
and when running on TPU.
+ Warning: Because the function will be called again on the backwards pass, the
+ user should be careful to not use ops in their function that mutate state or
+ have randomness (for example, batch normalization or dropout). If the function
+ does have such operations, it is recommended that the function take the
+ `is_recomputing` keyword argument which will be `False` on the forward pass
+ and `True` on the backwards pass so that it can disable state changes when
+ `is_recomputing=True` (for example, not updating the moving averages in batch
+ normalization).
+
Args:
fn: a function that takes Tensors (all as positional arguments) and returns
a tuple of Tensors.
@@ -482,6 +492,7 @@ def _is_on_tpu():
def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
+ has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
for arg in args:
if not isinstance(arg, framework_ops.Tensor):
raise ValueError("All inputs to function must be Tensors")
@@ -496,7 +507,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
vs = variable_scope.get_variable_scope()
arg_scope = contrib_framework_ops.current_arg_scope()
with backprop.GradientTape() as tape:
- outputs = fn(*args)
+ fn_kwargs = {}
+ if has_is_recompute_kwarg:
+ fn_kwargs["is_recomputing"] = False
+ outputs = fn(*args, **fn_kwargs)
original_vars = set(tape.watched_variables())
# Backward pass
@@ -516,7 +530,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
with contrib_framework_ops.arg_scope(arg_scope):
with variable_scope.variable_scope(vs, reuse=True):
with backprop.GradientTape() as tape:
- outputs = fn(*inputs)
+ fn_kwargs = {}
+ if has_is_recompute_kwarg:
+ fn_kwargs["is_recomputing"] = True
+ outputs = fn(*inputs, **fn_kwargs)
recompute_vars = set(tape.watched_variables())
if original_vars != recompute_vars:
raise ValueError(_WRONG_VARS_ERR)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index 997f53b9e1..bc09ba8d43 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.layers import rev_block_lib
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import core as core_layers
+from tensorflow.python.layers import normalization as normalization_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@@ -342,6 +344,34 @@ class RecomputeTest(test.TestCase):
for grad in grads:
self.assertTrue(grad is not None)
+ def testWithIsRecomputeKwarg(self):
+
+ kwarg_values = []
+
+ @rev_block_lib.recompute_grad
+ def layer_with_recompute(inputs, is_recomputing=False):
+ kwarg_values.append(is_recomputing)
+ out = core_layers.dense(inputs, 2)
+ out = normalization_layers.batch_normalization(out, training=True)
+ if is_recomputing:
+ # Ensure that the updates are not duplicated by popping off the latest
+ # 2 additions.
+ update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
+ update_ops.pop()
+ update_ops.pop()
+ return out
+
+ x = array_ops.ones((2, 4), dtypes.float32)
+ with variable_scope.variable_scope("layer1", use_resource=True):
+ y = layer_with_recompute(x)
+ loss = math_ops.reduce_sum(y)
+ tvars = variables.trainable_variables()
+ gradients_impl.gradients(loss, [x] + tvars)
+
+ update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
+ self.assertEqual(2, len(update_ops))
+ self.assertEqual([False, True], kwarg_values)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 70b70af98c..e100bc7a1e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -31,7 +31,6 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
-from tensorflow.python.training import training_util
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
@@ -51,6 +50,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training as train
+from tensorflow.python.training import training_util
# The default learning rate of 0.2 is a historical artifact of the initial
@@ -244,7 +244,9 @@ def sdca_model_fn(features, labels, mode, params):
parent_scope = "linear"
with variable_scope.variable_scope(
- values=features.values(), name_or_scope=parent_scope) as scope:
+ values=features.values(),
+ name_or_scope=parent_scope,
+ partitioner=optimizer.partitioner) as scope:
features = features.copy()
features.update(layers.transform_features(features, feature_columns))
logits, columns_to_variables, bias = (
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index 0a863f0e20..597ca4e86d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -43,6 +43,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
from tensorflow.python.training import ftrl
from tensorflow.python.training import input as input_lib
@@ -966,6 +967,63 @@ class LinearClassifierTest(test.TestCase):
scores = classifier.evaluate(input_fn=input_fn, steps=1)
self.assertGreater(scores['accuracy'], 0.9)
+ def testSdcaOptimizerPartitionedVariables(self):
+ """Tests LinearClassifier with SDCAOptimizer with partitioned variables."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [1.0], [1.0]])
+ }, constant_op.constant([[1], [0], [1]])
+
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id',
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+
+ tf_config = {
+ 'cluster': {
+ run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
+ }
+ }
+ with test.mock.patch.dict('os.environ',
+ {'TF_CONFIG': json.dumps(tf_config)}):
+ config = run_config.RunConfig()
+ # Because we did not start a distributed cluster, we need to pass an
+ # empty ClusterSpec, otherwise the device_setter will look for
+ # distributed jobs, such as "/job:ps" which are not present.
+ config._cluster_spec = server_lib.ClusterSpec({})
+
+ classifier = linear.LinearClassifier(
+ feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
+ weight_column_name='weights',
+ optimizer=sdca_optimizer,
+ config=config)
+ classifier.fit(input_fn=input_fn, steps=50)
+ scores = classifier.evaluate(input_fn=input_fn, steps=1)
+ print('all scores = {}'.format(scores))
+ self.assertGreater(scores['accuracy'], 0.9)
+
def testEval(self):
"""Tests that eval produces correct metrics.
"""
@@ -1540,6 +1598,60 @@ class LinearRegressorTest(test.TestCase):
loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
self.assertLess(loss, 0.05)
+ def testSdcaOptimizerPartitionedVariables(self):
+ """Tests LinearRegressor with SDCAOptimizer with partitioned variables."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([0.6, 0.8, 0.3]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [5.0], [7.0]])
+ }, constant_op.constant([[1.55], [-1.25], [-3.0]])
+
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id', symmetric_l2_regularization=1.0,
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+ tf_config = {
+ 'cluster': {
+ run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
+ }
+ }
+ with test.mock.patch.dict('os.environ',
+ {'TF_CONFIG': json.dumps(tf_config)}):
+ config = run_config.RunConfig()
+ # Because we did not start a distributed cluster, we need to pass an
+ # empty ClusterSpec, otherwise the device_setter will look for
+ # distributed jobs, such as "/job:ps" which are not present.
+ config._cluster_spec = server_lib.ClusterSpec({})
+
+ regressor = linear.LinearRegressor(
+ feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
+ weight_column_name='weights',
+ optimizer=sdca_optimizer,
+ config=config)
+ regressor.fit(input_fn=input_fn, steps=20)
+ loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.05)
+
def testSdcaOptimizerSparseFeaturesWithL1Reg(self):
"""Tests LinearClassifier with SDCAOptimizer and sparse features."""
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index b5741967ab..d0c32b43cc 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -35,6 +35,8 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_sdca_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import googletest
@@ -132,15 +134,22 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
return examples_dict, variables_dict
-def make_variable_dict(max_age, max_gender):
+def make_variable_dict(max_age, max_gender, partitioned=False):
# TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from
# examples_dict.
- age_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_age + 1], dtype=dtypes.float32))
- gender_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_gender + 1], dtype=dtypes.float32))
+ partitioner = None
+ if partitioned:
+ partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2,
+ axis=0)
+ with variable_scope.variable_scope(
+ name_or_scope='variables',
+ partitioner=partitioner):
+ age_weights = variables_lib.Variable(
+ array_ops.zeros(
+ [max_age + 1], dtype=dtypes.float32))
+ gender_weights = variables_lib.Variable(
+ array_ops.zeros(
+ [max_gender + 1], dtype=dtypes.float32))
return dict(
sparse_features_weights=[age_weights, gender_weights],
dense_features_weights=[])
@@ -265,6 +274,54 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
self.assertAllClose(
0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+ def testPartitionedPrimals(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [1],
+ 'gender': [1]
+ }, 1),
+ ]
+ example_weights = [1.0, 1.0]
+ for num_shards in _SHARD_NUMBERS:
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1, partitioned=True)
+ options = dict(
+ symmetric_l2_regularization=1,
+ symmetric_l1_regularization=0,
+ num_table_shards=num_shards,
+ loss_type='logistic_loss')
+
+ lr = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+ unregularized_loss = lr.unregularized_loss(examples)
+ loss = lr.regularized_loss(examples)
+ predictions = lr.predictions(examples)
+ self.assertAllClose(0.693147, unregularized_loss.eval())
+ self.assertAllClose(0.693147, loss.eval())
+ train_op = lr.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ lr.update_weights(train_op).run()
+ # The high tolerance in unregularized_loss comparisons is due to the
+ # fact that it's possible to trade off unregularized_loss vs.
+ # regularization and still have a sum that is quite close to the
+ # optimal regularized_loss value. SDCA's duality gap only ensures that
+ # the regularized_loss is within 0.01 of optimal.
+ # 0.525457 is the optimal regularized_loss.
+ # 0.411608 is the unregularized_loss at that optimum.
+ self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)
+ self.assertAllClose(0.525457, loss.eval(), atol=0.01)
+ predicted_labels = get_binary_predictions_for_logistic(predictions)
+ self.assertAllEqual([0, 1], predicted_labels.eval())
+ self.assertAllClose(
+ 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+
def testSparseRandom(self):
dim = 20
num_examples = 1000
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index f980746a19..0047d5753a 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -22,12 +22,14 @@ import collections
from six.moves import range
from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework.ops import internal_convert_to_tensor
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_sdca_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -43,9 +45,6 @@ __all__ = ['SdcaModel']
class SdcaModel(object):
"""Stochastic dual coordinate ascent solver for linear models.
- This class currently only supports a single machine (multi-threaded)
- implementation. We expect the weights and duals to fit in a single machine.
-
Loss functions supported:
* Binary logistic loss
@@ -182,18 +181,41 @@ class SdcaModel(object):
# TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic.
def _create_slots(self):
- # Make internal variables which have the updates before applying L1
- # regularization.
+ """Make unshrinked internal variables (slots)."""
+ # Unshrinked variables have the updates before applying L1 regularization.
+ # Each unshrinked slot variable is either a `Variable` or list of
+ # `Variable`, depending on the value of its corresponding primary variable.
+ # We avoid using `PartitionedVariable` for the unshrinked slots since we do
+ # not need any of the extra information.
self._slots = collections.defaultdict(list)
for name in ['sparse_features_weights', 'dense_features_weights']:
for var in self._variables[name]:
- with ops.device(var.device):
- # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is
- # fixed
- self._slots['unshrinked_' + name].append(
- var_ops.Variable(
- array_ops.zeros_like(var.initialized_value(), dtypes.float32),
- name=var.op.name + '_unshrinked/SDCAOptimizer'))
+ # Our primary variable may be either a PartitionedVariable, or a list
+ # of Variables (each representing a partition).
+ if (isinstance(var, var_ops.PartitionedVariable) or
+ isinstance(var, list)):
+ var_list = []
+ # pylint: disable=protected-access
+ for v in var:
+ with ops.colocate_with(v):
+ # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109
+ # is fixed.
+ slot_var = var_ops.Variable(
+ initial_value=array_ops.zeros_like(v.initialized_value(),
+ dtypes.float32),
+ name=v.op.name + '_unshrinked/SDCAOptimizer')
+ var_list.append(slot_var)
+ self._slots['unshrinked_' + name].append(var_list)
+ # pylint: enable=protected-access
+ else:
+ with ops.device(var.device):
+ # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is
+ # fixed.
+ self._slots['unshrinked_' + name].append(
+ var_ops.Variable(
+ array_ops.zeros_like(var.initialized_value(),
+ dtypes.float32),
+ name=var.op.name + '_unshrinked/SDCAOptimizer'))
def _assertSpecified(self, items, check_in):
for x in items:
@@ -205,16 +227,25 @@ class SdcaModel(object):
if not isinstance(check_in[x], list):
raise ValueError(x + ' must be a list.')
+ def _var_to_list(self, var):
+ """Wraps var in a list if it is not a list or PartitionedVariable."""
+ if not (isinstance(var, list) or
+ isinstance(var, var_ops.PartitionedVariable)):
+ var = [var]
+ return var
+
def _l1_loss(self):
"""Computes the (un-normalized) l1 loss of the model."""
with name_scope('sdca/l1_loss'):
sums = []
for name in ['sparse_features_weights', 'dense_features_weights']:
- for weights in self._convert_n_to_tensor(self._variables[name]):
- with ops.device(weights.device):
- sums.append(
- math_ops.reduce_sum(
- math_ops.abs(math_ops.cast(weights, dtypes.float64))))
+ for var in self._variables[name]:
+ for v in self._var_to_list(var):
+ weights = internal_convert_to_tensor(v)
+ with ops.device(weights.device):
+ sums.append(
+ math_ops.reduce_sum(
+ math_ops.abs(math_ops.cast(weights, dtypes.float64))))
# SDCA L1 regularization cost is: l1 * sum(|weights|)
return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums)
@@ -223,17 +254,37 @@ class SdcaModel(object):
with name_scope('sdca/l2_loss'):
sums = []
for name in ['sparse_features_weights', 'dense_features_weights']:
- for weights in self._convert_n_to_tensor(self._variables[name]):
- with ops.device(weights.device):
- sums.append(
- math_ops.reduce_sum(
- math_ops.square(math_ops.cast(weights, dtypes.float64))))
+ for var in self._variables[name]:
+ for v in self._var_to_list(var):
+ weights = internal_convert_to_tensor(v)
+ with ops.device(weights.device):
+ sums.append(math_ops.reduce_sum(math_ops.square(math_ops.cast(
+ weights, dtypes.float64))))
# SDCA L2 regularization cost is: l2 * sum(weights^2) / 2
return l2 * math_ops.add_n(sums) / 2.0
def _convert_n_to_tensor(self, input_list, as_ref=False):
"""Converts input list to a set of tensors."""
- return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list]
+ # input_list can be a list of Variables (that are implicitly partitioned),
+ # in which case the underlying logic in internal_convert_to_tensor will not
+ # concatenate the partitions together. This method takes care of the
+ # concatenating (we only allow partitioning on the first axis).
+ output_list = []
+ for x in input_list:
+ tensor_to_convert = x
+ if isinstance(x, list) or isinstance(x, var_ops.PartitionedVariable):
+ # We only allow for partitioning on the first axis.
+ tensor_to_convert = array_ops.concat(x, axis=0)
+ output_list.append(internal_convert_to_tensor(
+ tensor_to_convert, as_ref=as_ref))
+ return output_list
+
+ def _get_first_dimension_size_statically(self, w, num_partitions):
+ """Compute the static size of the first dimension for a sharded variable."""
+ dim_0_size = w[0].get_shape()[0]
+ for p in range(1, num_partitions):
+ dim_0_size += w[p].get_shape()[0]
+ return dim_0_size
def _linear_predictions(self, examples):
"""Returns predictions of the form w*x."""
@@ -286,6 +337,28 @@ class SdcaModel(object):
result = math_ops.sigmoid(result)
return result
+ def _get_partitioned_update_ops(self,
+ v_num,
+ num_partitions_by_var,
+ p_assignments_by_var,
+ gather_ids_by_var,
+ weights,
+ full_update,
+ p_assignments,
+ num_partitions):
+ """Get updates for partitioned variables."""
+ num_partitions = num_partitions_by_var[v_num]
+ p_assignments = p_assignments_by_var[v_num]
+ gather_ids = gather_ids_by_var[v_num]
+ updates = data_flow_ops.dynamic_partition(
+ full_update, p_assignments, num_partitions)
+ update_ops = []
+ for p in range(num_partitions):
+ with ops.colocate_with(weights[p]):
+ result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p])
+ update_ops.append(result)
+ return update_ops
+
def minimize(self, global_step=None, name=None):
"""Add operations to train a linear model by minimizing the loss function.
@@ -318,18 +391,89 @@ class SdcaModel(object):
# Solver returns example_state_update, new delta sparse_feature_weights
# and delta dense_feature_weights.
- weights_tensor = self._convert_n_to_tensor(self._slots[
- 'unshrinked_sparse_features_weights'])
sparse_weights = []
sparse_indices = []
- for w, i in zip(weights_tensor, sparse_feature_indices):
- # Find the feature ids to lookup in the variables.
- with ops.device(w.device):
- sparse_indices.append(
- math_ops.cast(
- array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
- dtypes.int64))
- sparse_weights.append(array_ops.gather(w, sparse_indices[-1]))
+ # If we have partitioned variables, keep a few lists of Tensors around
+ # that we need for the assign_add after the op call to
+ # gen_sdca_ops.sdca_optimizer().
+ num_partitions_by_var = []
+ p_assignments_by_var = []
+ gather_ids_by_var = []
+ for w, i in zip(self._slots['unshrinked_sparse_features_weights'],
+ sparse_feature_indices):
+ # Append the sparse_indices (in full-variable space).
+ sparse_idx = math_ops.cast(
+ array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
+ dtypes.int64)
+ sparse_indices.append(sparse_idx)
+ if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable):
+ num_partitions = len(w)
+ flat_ids = array_ops.reshape(sparse_idx, [-1])
+ # We use div partitioning, which is easiest to support downstream.
+ # Compute num_total_ids as the sum of dim-0 of w, then assign
+ # to partitions based on a constant number of ids per partition.
+ # Optimize if we already know the full shape statically.
+ dim_0_size = self._get_first_dimension_size_statically(
+ w, num_partitions)
+
+ if dim_0_size.value:
+ num_total_ids = constant_op.constant(dim_0_size.value,
+ flat_ids.dtype)
+ else:
+ dim_0_sizes = []
+ for p in range(num_partitions):
+ if w[p].get_shape()[0].value is not None:
+ dim_0_sizes.append(w[p].get_shape()[0].value)
+ else:
+ with ops.colocate_with(w[p]):
+ dim_0_sizes.append(array_ops.shape(w[p])[0])
+ num_total_ids = math_ops.reduce_sum(
+ math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
+ ids_per_partition = num_total_ids // num_partitions
+ extras = num_total_ids % num_partitions
+
+ p_assignments = math_ops.maximum(
+ flat_ids // (ids_per_partition + 1),
+ (flat_ids - extras) // ids_per_partition)
+
+ # Emulate a conditional using a boolean indicator tensor
+ new_ids = array_ops.where(p_assignments < extras,
+ flat_ids % (ids_per_partition + 1),
+ (flat_ids - extras) % ids_per_partition)
+
+ # Cast partition assignments to int32 for use in dynamic_partition.
+ # There really should not be more than 2^32 partitions.
+ p_assignments = math_ops.cast(p_assignments, dtypes.int32)
+ # Partition list of ids based on assignments into num_partitions
+ # separate lists.
+ gather_ids = data_flow_ops.dynamic_partition(new_ids,
+ p_assignments,
+ num_partitions)
+ # Append these to the lists for use in the later update.
+ num_partitions_by_var.append(num_partitions)
+ p_assignments_by_var.append(p_assignments)
+ gather_ids_by_var.append(gather_ids)
+
+ # Gather the weights from each partition.
+ partition_gathered_weights = []
+ for p in range(num_partitions):
+ with ops.colocate_with(w[p]):
+ partition_gathered_weights.append(
+ array_ops.gather(w[p], gather_ids[p]))
+
+ # Stitch the weights back together in the same order they were before
+ # we dynamic_partitioned them.
+ condition_indices = data_flow_ops.dynamic_partition(
+ math_ops.range(array_ops.shape(new_ids)[0]),
+ p_assignments, num_partitions)
+ batch_gathered_weights = data_flow_ops.dynamic_stitch(
+ condition_indices, partition_gathered_weights)
+ else:
+ w_as_tensor = internal_convert_to_tensor(w)
+ with ops.device(w_as_tensor.device):
+ batch_gathered_weights = array_ops.gather(
+ w_as_tensor, sparse_idx)
+ sparse_weights.append(batch_gathered_weights)
# pylint: disable=protected-access
esu, sfw, dfw = gen_sdca_ops.sdca_optimizer(
@@ -355,12 +499,25 @@ class SdcaModel(object):
with ops.control_dependencies([esu]):
update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
# Update the weights before the proximal step.
- for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'],
- sparse_indices, sfw):
- update_ops.append(state_ops.scatter_add(w, i, u))
+ for v_num, (w, i, u) in enumerate(
+ zip(self._slots['unshrinked_sparse_features_weights'],
+ sparse_indices, sfw)):
+ if (isinstance(w, var_ops.PartitionedVariable) or
+ isinstance(w, list)):
+ update_ops += self._get_partitioned_update_ops(
+ v_num, num_partitions_by_var, p_assignments_by_var,
+ gather_ids_by_var, w, u, p_assignments, num_partitions)
+ else:
+ update_ops.append(state_ops.scatter_add(w, i, u))
for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
- update_ops.append(w.assign_add(u))
-
+ if (isinstance(w, var_ops.PartitionedVariable) or
+ isinstance(w, list)):
+ split_updates = array_ops.split(
+ u, num_or_size_splits=[v.shape.as_list()[0] for v in w])
+ for v, split_update in zip(w, split_updates):
+ update_ops.append(state_ops.assign_add(v, split_update))
+ else:
+ update_ops.append(state_ops.assign_add(w, u))
if not global_step:
return control_flow_ops.group(*update_ops)
with ops.control_dependencies(update_ops):
@@ -385,21 +542,22 @@ class SdcaModel(object):
for name in ['sparse_features_weights', 'dense_features_weights']:
for var, slot_var in zip(self._variables[name],
self._slots['unshrinked_' + name]):
- update_ops.append(var.assign(slot_var))
+ for v, sv in zip(self._var_to_list(var), self._var_to_list(slot_var)):
+ update_ops.append(v.assign(sv))
# Apply proximal step.
with ops.control_dependencies(update_ops):
update_ops = []
for name in ['sparse_features_weights', 'dense_features_weights']:
for var in self._variables[name]:
- with ops.device(var.device):
- # pylint: disable=protected-access
- update_ops.append(
- gen_sdca_ops.sdca_shrink_l1(
- self._convert_n_to_tensor(
- [var], as_ref=True),
- l1=self._symmetric_l1_regularization(),
- l2=self._symmetric_l2_regularization()))
+ for v in self._var_to_list(var):
+ with ops.device(v.device):
+ # pylint: disable=protected-access
+ update_ops.append(
+ gen_sdca_ops.sdca_shrink_l1(
+ self._convert_n_to_tensor([v], as_ref=True),
+ l1=self._symmetric_l1_regularization(),
+ l2=self._symmetric_l2_regularization()))
return control_flow_ops.group(*update_ops)
def approximate_duality_gap(self):
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py
index d4e54c82f9..200e7de6b9 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py
@@ -116,6 +116,7 @@ def sdca_model_fn(features, labels, mode, params, config=None):
num_loss_partitions = params["num_loss_partitions"]
weight_column_name = params["weight_column_name"]
update_weights_hook = params.get("update_weights_hook", None)
+ partitioner = params["partitioner"]
loss_type = None
if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access
@@ -136,12 +137,14 @@ def sdca_model_fn(features, labels, mode, params, config=None):
example_id_column=example_id_column,
num_loss_partitions=n_loss_partitions,
symmetric_l1_regularization=l1_regularization,
- symmetric_l2_regularization=l2_regularization)
+ symmetric_l2_regularization=l2_regularization,
+ partitioner=partitioner)
parent_scope = "linear"
with variable_scope.variable_scope(
- values=features.values(), name_or_scope=parent_scope) as scope:
+ values=features.values(), name_or_scope=parent_scope,
+ partitioner=partitioner) as scope:
features = features.copy()
features.update(layers.transform_features(features, feature_columns))
logits, columns_to_variables, bias = (
@@ -213,7 +216,8 @@ class _SDCAEstimator(estimator.Estimator):
l2_regularization=1.0,
num_loss_partitions=None,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ partitioner=None):
"""Construct a `_SDCAEstimator` estimator object.
Args:
@@ -241,6 +245,8 @@ class _SDCAEstimator(estimator.Estimator):
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
+ partitioner: Variable partitioner for the primal weights (`div`
+ partitioning strategy will be used).
Returns:
A `_SDCAEstimator` estimator.
@@ -267,6 +273,7 @@ class _SDCAEstimator(estimator.Estimator):
"l2_regularization": l2_regularization,
"weight_column_name": weight_column_name,
"update_weights_hook": _SdcaUpdateWeightsHook(),
+ "partitioner": partitioner,
}
super(_SDCAEstimator, self).__init__(
@@ -336,7 +343,8 @@ class SDCALogisticClassifier(_SDCAEstimator):
l2_regularization=1.0,
num_loss_partitions=None,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ partitioner=None):
"""Construct a `SDCALogisticClassifier` object.
Args:
@@ -361,6 +369,8 @@ class SDCALogisticClassifier(_SDCAEstimator):
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
+ partitioner: Variable partitioner for the primal weights (`div`
+ partitioning strategy will be used).
Returns:
A `SDCALogisiticClassifier` estimator.
@@ -376,7 +386,8 @@ class SDCALogisticClassifier(_SDCAEstimator):
l2_regularization=l2_regularization,
num_loss_partitions=num_loss_partitions,
config=config,
- feature_engineering_fn=None)
+ feature_engineering_fn=None,
+ partitioner=partitioner)
def predict_classes(self, input_fn=None):
"""Runs inference to determine the predicted class.
@@ -463,7 +474,8 @@ class SDCALinearRegressor(_SDCAEstimator):
l2_regularization=1.0,
num_loss_partitions=None,
config=None,
- feature_engineering_fn=None):
+ feature_engineering_fn=None,
+ partitioner=None):
"""Construct a `SDCALinearRegressor` estimator object.
@@ -489,6 +501,8 @@ class SDCALinearRegressor(_SDCAEstimator):
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
+ partitioner: Variable partitioner for the primal weights (`div`
+ partitioning strategy will be used).
Returns:
A `SDCALinearRegressor` estimator.
@@ -503,7 +517,8 @@ class SDCALinearRegressor(_SDCAEstimator):
l2_regularization=l2_regularization,
num_loss_partitions=num_loss_partitions,
config=config,
- feature_engineering_fn=None)
+ feature_engineering_fn=None,
+ partitioner=partitioner)
def predict_scores(self, input_fn):
"""Returns predicted scores for given features.
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py
index bed3d5139f..6476671882 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.linear_optimizer.python import sdca_estimator
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
@@ -273,6 +274,47 @@ class SDCALogisticClassifierTest(test.TestCase):
metrics = classifier.evaluate(input_fn=input_fn, steps=1)
self.assertGreater(metrics['accuracy'], 0.9)
+ def testPartitionedMixedFeatures(self):
+ """Tests SDCALogisticClassifier with a mix of features (partitioned)."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([900.0, 700.0, 600.0]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [1.0], [1.0]])
+ }, constant_op.constant([[1], [0], [1]])
+
+ with self._single_threaded_test_session():
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ classifier = sdca_estimator.SDCALogisticClassifier(
+ example_id_column='example_id',
+ feature_columns=[
+ price, sq_footage_bucket, country, sq_footage_country
+ ],
+ weight_column_name='weights',
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+ classifier.fit(input_fn=input_fn, steps=50)
+ metrics = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(metrics['accuracy'], 0.9)
+
class SDCALinearRegressorTest(test.TestCase):
@@ -350,6 +392,48 @@ class SDCALinearRegressorTest(test.TestCase):
loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
self.assertLess(loss, 0.05)
+ def testMixedFeaturesArbitraryWeightsPartitioned(self):
+ """Tests SDCALinearRegressor works with a mix of features (partitioned)."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [5.0], [7.0]])
+ }, constant_op.constant([[1.55], [-1.25], [-3.0]])
+
+ with self._single_threaded_test_session():
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ regressor = sdca_estimator.SDCALinearRegressor(
+ example_id_column='example_id',
+ feature_columns=[
+ price, sq_footage_bucket, country, sq_footage_country
+ ],
+ l2_regularization=1.0,
+ weight_column_name='weights',
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+ regressor.fit(input_fn=input_fn, steps=20)
+ loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.05)
+
def testSdcaOptimizerSparseFeaturesWithL1Reg(self):
"""SDCALinearRegressor works with sparse features and L1 regularization."""
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 12039ecc6f..9872c6f97c 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -64,7 +64,8 @@ class SDCAOptimizer(object):
of workers running the train steps. It defaults to 1 (single machine).
`num_table_shards` defines the number of shards for the internal state
table, typically set to match the number of parameter servers for large
- data sets.
+ data sets. You can also specify a `partitioner` object to partition the primal
+ weights during training (`div` partitioning strategy will be used).
"""
def __init__(self,
@@ -73,13 +74,15 @@ class SDCAOptimizer(object):
num_table_shards=None,
symmetric_l1_regularization=0.0,
symmetric_l2_regularization=1.0,
- adaptive=True):
+ adaptive=True,
+ partitioner=None):
self._example_id_column = example_id_column
self._num_loss_partitions = num_loss_partitions
self._num_table_shards = num_table_shards
self._symmetric_l1_regularization = symmetric_l1_regularization
self._symmetric_l2_regularization = symmetric_l2_regularization
self._adaptive = adaptive
+ self._partitioner = partitioner
def get_name(self):
return 'SDCAOptimizer'
@@ -108,6 +111,10 @@ class SDCAOptimizer(object):
def adaptive(self):
return self._adaptive
+ @property
+ def partitioner(self):
+ return self._partitioner
+
def get_train_step(self, columns_to_variables, weight_column_name, loss_type,
features, targets, global_step):
"""Returns the training operation of an SdcaModel optimizer."""
@@ -175,10 +182,12 @@ class SDCAOptimizer(object):
sparse_feature_column = _dense_tensor_to_sparse_feature_column(
dense_bucket_tensor)
sparse_feature_with_values.append(sparse_feature_column)
- # For bucketized columns, the variables list contains exactly one
- # element.
- sparse_feature_with_values_weights.append(
- columns_to_variables[column][0])
+ # If a partitioner was used during variable creation, we will have a
+ # list of Variables here larger than 1.
+ vars_to_append = columns_to_variables[column][0]
+ if len(columns_to_variables[column]) > 1:
+ vars_to_append = columns_to_variables[column]
+ sparse_feature_with_values_weights.append(vars_to_append)
elif isinstance(
column,
(
@@ -226,8 +235,12 @@ class SDCAOptimizer(object):
array_ops.shape(ids)[0]), [-1])
sparse_feature_with_values.append(
SparseFeatureColumn(example_ids_filtered, reproject_ids, weights))
- sparse_feature_with_values_weights.append(
- columns_to_variables[column][0])
+ # If a partitioner was used during variable creation, we will have a
+ # list of Variables here larger than 1.
+ vars_to_append = columns_to_variables[column][0]
+ if len(columns_to_variables[column]) > 1:
+ vars_to_append = columns_to_variables[column]
+ sparse_feature_with_values_weights.append(vars_to_append)
else:
raise ValueError('SDCAOptimizer does not support column type %s.' %
type(column).__name__)
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 9bfc0a0fbe..c8820ab29b 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -212,12 +212,13 @@ def generated_test_models():
"global_batch_norm",
"greater",
"greater_equal",
- "l2_pool",
"l2norm",
+ "l2_pool",
"less",
"less_equal",
"local_response_norm",
"log_softmax",
+ "lstm",
"max_pool",
"maximum",
"mean",
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 7e285186f4..24a9b0f6b8 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -99,4 +99,3 @@ typedef enum {
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_
-}
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 0b35a220e7..ee42e5cdc8 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -254,6 +254,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
+ TF_LITE_ENSURE(context, real_multiplier < 1.0);
QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
&data->output_shift);
CalculateActivationRangeUint8(params->activation, output,
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index abb2549f85..a308de055f 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -151,8 +151,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
- QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
- &data->output_shift);
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
CalculateActivationRangeUint8(params->activation, output,
&data->output_activation_min,
&data->output_activation_max);
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index 1439c8bce1..c00cafb9fb 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -47,12 +47,6 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
}
output_ = AddOutput(output);
- if (input.type != TensorType_FLOAT32) {
- // The following is required by quantized inference. It is the unittest's
- // responsibility to make sure the output scale falls into the correct
- // range.
- CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
- }
int input_depth = GetShape(input_)[3];
int output_depth = GetShape(filter_)[3];
@@ -176,6 +170,43 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
}));
}
+TEST(QuantizedDepthwiseConvolutionOpTest,
+ SimpleTestQuantizedFilterMultiplierGreaterThan1) {
+ QuantizedDepthwiseConvolutionOpModel quant_op(
+ {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
+ {TensorType_UINT8, {}, -127, 128});
+ DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}},
+ {TensorType_FLOAT32, {1, 2, 2, 4}},
+ {TensorType_FLOAT32, {}});
+
+ std::initializer_list<float> input = {
+ 1, 2, 7, 8, // column 1
+ 3, 4, 9, 10, // column 2
+ 5, 6, 11, 12, // column 3
+ };
+ std::initializer_list<float> filter = {
+ 1, 2, 3, 4, //
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ };
+ std::initializer_list<float> bias = {1, 2, 3, 4};
+
+ quant_op.SetInput(input);
+ quant_op.SetFilter(filter);
+ quant_op.SetBias(bias);
+ quant_op.Invoke();
+
+ float_op.SetInput(input);
+ float_op.SetFilter(filter);
+ float_op.SetBias(bias);
+ float_op.Invoke();
+
+ EXPECT_THAT(quant_op.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 3374923e6e..989920622d 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -101,6 +101,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_size *= input->dims->data[i];
}
+ TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
const int batch_size = input_size / filter->dims->data[1];
const int num_units = filter->dims->data[0];
@@ -109,8 +110,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
- TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
-
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
TfLiteType data_type = input->type;
@@ -118,6 +117,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
+ TF_LITE_ENSURE(context, real_multiplier < 1.0);
QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
&data->output_shift);
CalculateActivationRangeUint8(params->activation, output,
@@ -218,11 +218,8 @@ TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node,
tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
}
- // TODO(mirkov): change std::minmax_element with a vectorized call.
- auto minmax_element =
- std::minmax_element(input->data.f, input->data.f + total_input_size);
// Save matrix multiplication computation for all zero input.
- if (*minmax_element.first == 0.0 && *minmax_element.second == 0.0) {
+ if (tensor_utils::IsZeroVector(input->data.f, total_input_size)) {
tensor_utils::ApplyActivationToVector(output->data.f,
batch_size * num_units,
params->activation, output->data.f);
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index aabbb0685c..0a5223b235 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -420,6 +420,15 @@ cc_library(
}),
)
+cc_library(
+ name = "test_util",
+ srcs = ["test_util.cc"],
+ hdrs = ["test_util.h"],
+ deps = [
+ ":types",
+ ],
+)
+
cc_test(
name = "tensor_utils_test",
srcs = ["tensor_utils_test.cc"],
@@ -440,6 +449,83 @@ cc_test(
],
)
+cc_test(
+ name = "depthwiseconv_float_test",
+ srcs = ["depthwiseconv_float_test.cc"],
+ deps = [
+ ":optimized_base",
+ ":reference_base",
+ ":test_util",
+ ":types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "depthwiseconv_quantized_test",
+ srcs = ["depthwiseconv_quantized_test.cc"],
+ deps = [
+ ":optimized_base",
+ ":reference_base",
+ ":test_util",
+ ":types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "resize_bilinear_float_test",
+ srcs = ["resize_bilinear_float_test.cc"],
+ deps = [
+ ":optimized_base",
+ ":reference_base",
+ ":test_util",
+ ":types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "softmax_quantized_test",
+ timeout = "long",
+ srcs = [
+ "softmax_quantized_test.cc",
+ ],
+ deps = [
+ ":optimized_base",
+ ":quantization_util",
+ ":reference_base",
+ ":test_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "logsoftmax_quantized_test",
+ timeout = "long",
+ srcs = [
+ "logsoftmax_quantized_test.cc",
+ ],
+ tags = ["tflite_not_portable"],
+ deps = [
+ ":optimized_base",
+ ":quantization_util",
+ ":reference_base",
+ ":test_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "log_quantized_test",
+ srcs = ["log_quantized_test.cc"],
+ deps = [
+ ":optimized_base",
+ ":reference_base",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
cc_library(
name = "cpu_check",
hdrs = [
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
new file mode 100644
index 0000000000..844ee6a53d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
@@ -0,0 +1,162 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+
+namespace tflite {
+namespace {
+
+// Runs the DepthwiseConv and compares against the reference implementation.
+template <FusedActivationFunctionType Ac>
+void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride, int pad_width, int pad_height,
+ int depth_multiplier, const Dims<4>& output_dims) {
+ const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
+ std::vector<float> output_data(output_buffer_size);
+ std::vector<float> reference_output_data(output_buffer_size);
+ reference_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
+ filter_dims, bias_data, bias_dims, stride,
+ pad_width, pad_height, depth_multiplier,
+ reference_output_data.data(), output_dims);
+ optimized_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
+ filter_dims, bias_data, bias_dims, stride,
+ pad_width, pad_height, depth_multiplier,
+ output_data.data(), output_dims);
+ double sum_abs_diff = 0;
+ float max_abs_val = 0;
+ for (int i = 0; i < output_buffer_size; i++) {
+ sum_abs_diff += std::abs(output_data[i] - reference_output_data[i]);
+ max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i]));
+ }
+ if (sum_abs_diff != 0.f) {
+ const float mean_diff =
+ static_cast<float>(sum_abs_diff / output_buffer_size);
+ const float relative_error = std::abs(mean_diff) / max_abs_val;
+ ASSERT_LT(relative_error, 1e-5f);
+ }
+}
+
+void TestOneDepthwiseConv(FusedActivationFunctionType Ac,
+ const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride, int pad_width, int pad_height,
+ int depth_multiplier, const Dims<4>& output_dims) {
+#define TOCO_HANDLE_CASE(AC_TYPE) \
+ if (AC_TYPE == Ac) { \
+ TestOneDepthwiseConv<AC_TYPE>(input_data, input_dims, filter_data, \
+ filter_dims, bias_data, bias_dims, stride, \
+ pad_width, pad_height, depth_multiplier, \
+ output_dims); \
+ return; \
+ }
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6)
+#undef TOCO_HANDLE_CASE
+}
+
+// This function picks some random DepthwiseConv params, which may or may not
+// be legal. If they're not legal, it returns false. If they're legal,
+// it runs the DepthwiseConv test and returns true. This allows the caller
+// to loop until a test has been run.
+bool TryTestOneDepthwiseConv() {
+ // We have to pick a lot of positive values, where we are particularly
+ // interested in small values because they are most likely to be special
+ // cases in optimized implementations, and secondarily because they allow
+ // tests to run fast, which means we can run more tests and get more
+ // coverage.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10);
+ const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
+ const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const int output_depth = input_depth * depth_multiplier;
+ // The optimized DepthwiseConv implementation currently uses a fixed-size
+ // accumulator buffer on the stack, with that size. This currently means
+ // that it does not support larger output depths. It CHECK's for it,
+ // so it's safe in the sense that if a larger output depth was encountered,
+ // it would explicitly fail. We just need to adjust our testing to that
+ // constraint.
+ const int kMaxSupportedOutputDepth = 1024;
+ if (output_depth > kMaxSupportedOutputDepth) {
+ return false;
+ }
+ const auto ac = RandomElement(std::vector<FusedActivationFunctionType>(
+ {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu,
+ FusedActivationFunctionType::kRelu6,
+ FusedActivationFunctionType::kRelu1}));
+ Dims<4> input_dims_inference =
+ MakeDimsForInference(input_depth, input_width, input_height, batch);
+ Dims<4> output_dims_inference;
+ int pad_width, pad_height;
+ const auto padding_type =
+ UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
+ if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
+ filter_height, stride, padding_type,
+ &output_dims_inference, &pad_width, &pad_height)) {
+ return false;
+ }
+ Dims<4> filter_dims_inference =
+ MakeDimsForInference(output_depth, filter_width, filter_height, 1);
+ Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1);
+ const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
+ const int filter_buffer_size =
+ RequiredBufferSizeForDims(filter_dims_inference);
+ std::vector<float> input_data(input_buffer_size);
+ std::vector<float> filter_data(filter_buffer_size);
+ std::vector<float> bias_data(output_depth);
+ const float input_amplitude = 1.f;
+ const float filter_amplitude = 1.f;
+ const float bias_amplitude =
+ filter_width * filter_height * input_amplitude * filter_amplitude;
+ FillRandom(&input_data, -input_amplitude, input_amplitude);
+ FillRandom(&filter_data, -filter_amplitude, filter_amplitude);
+ FillRandom(&bias_data, -bias_amplitude, bias_amplitude);
+ TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
+ filter_data.data(), filter_dims_inference,
+ bias_data.data(), bias_dims_inference, stride, pad_width,
+ pad_height, depth_multiplier, output_dims_inference);
+ return true;
+}
+
+void TestOneDepthwiseConv() {
+ while (!TryTestOneDepthwiseConv()) {
+ }
+}
+
+TEST(TestDepthwiseConv, TestDepthwiseConv) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ TestOneDepthwiseConv();
+ }
+}
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
new file mode 100644
index 0000000000..2c0fc8433e
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
@@ -0,0 +1,330 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <sys/types.h>
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
+#include <limits>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+
+namespace tflite {
+namespace {
+
+// Runs the DepthwiseConv and compares against the reference implementation.
+template <FusedActivationFunctionType Ac>
+int TestOneDepthwiseConvWithGivenOutputShift(
+ const std::uint8_t* input_data, const Dims<4>& input_dims,
+ std::int32_t input_offset, const std::uint8_t* filter_data,
+ const Dims<4>& filter_dims, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ std::int32_t output_offset, std::int32_t output_multiplier,
+ int output_shift, std::int32_t output_activation_min,
+ std::int32_t output_activation_max, const Dims<4>& output_dims) {
+ const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
+ std::vector<std::uint8_t> output_data(output_buffer_size);
+ std::vector<std::uint8_t> reference_output_data(output_buffer_size);
+ reference_ops::DepthwiseConv<Ac>(
+ input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ depth_multiplier, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max,
+ reference_output_data.data(), output_dims);
+ optimized_ops::DepthwiseConv<Ac>(
+ input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ depth_multiplier, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data.data(),
+ output_dims);
+ int saturated_min = 0;
+ int saturated_max = 0;
+ std::vector<int> diff(output_buffer_size);
+ std::int64_t sum_diff = 0;
+ std::int64_t sum_abs_diff = 0;
+ for (int i = 0; i < output_buffer_size; i++) {
+ diff[i] = static_cast<int>(output_data[i]) -
+ static_cast<int>(reference_output_data[i]);
+ sum_diff += diff[i];
+ sum_abs_diff += std::abs(diff[i]);
+ saturated_min += output_data[i] == output_activation_min;
+ saturated_max += output_data[i] == output_activation_max;
+ }
+ // These stats help understand test failures.
+ std::sort(std::begin(diff), std::end(diff));
+ const int min_diff = diff.front();
+ const int max_diff = diff.back();
+ const int median_diff = diff[diff.size() / 2];
+ const float mean_diff = static_cast<float>(sum_diff) / output_buffer_size;
+ const float mean_abs_diff =
+ static_cast<float>(sum_abs_diff) / output_buffer_size;
+ // Normally we should require bit-for-bit exact results. Unfortunately a bug
+ // in the Intel arm_neon_sse.h translation header that we use for x86 tests
+ // causes 1-bit inaccuracy in
+ // the vqrdmulh_n_s32 intrinsic, which causes off-by-1 errors in quantized
+ // DepthwiseConv ops. So we have to live with a few off-by-one errors for now,
+ // yet still ensure that no more than a small minority of values are wrong.
+ EXPECT_TRUE(std::abs(mean_diff) < 1e-5f && mean_abs_diff < 1e-5f &&
+ std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 &&
+ std::abs(max_diff) <= 1);
+ if (saturated_min > 2 * saturated_max) {
+ return -1;
+ }
+ if (saturated_max > 2 * saturated_min) {
+ return 1;
+ }
+ return 0;
+}
+
+// The point of this function is that we can't practically know which
+// output_shift value to pass to test DepthwiseConv. It's not easy to guess (we
+// could do some
+// statistics for large size, but they would be fragile at smaller sizes), and
+// guessing wrong would mean that all the values get saturated so the test
+// becomes
+// vacuous. So we just bisect our way to reasonable output_shift values.
+template <FusedActivationFunctionType Ac>
+void TestOneDepthwiseConvBisectOutputShift(
+ const std::uint8_t* input_data, const Dims<4>& input_dims,
+ std::int32_t input_offset, const std::uint8_t* filter_data,
+ const Dims<4>& filter_dims, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ std::int32_t output_offset, std::int32_t output_multiplier,
+ int output_activation_bisect_start, int output_activation_bisect_end,
+ std::int32_t output_activation_min, std::int32_t output_activation_max,
+ const Dims<4>& output_dims) {
+ ASSERT_LT(output_activation_bisect_start, output_activation_bisect_end)
+ << "Bisection failed ?!?!";
+ int output_shift_bisect_midpoint =
+ (output_activation_bisect_start + output_activation_bisect_end) / 2;
+ int bisect_result = TestOneDepthwiseConvWithGivenOutputShift<Ac>(
+ input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ depth_multiplier, output_offset, output_multiplier,
+ output_shift_bisect_midpoint, output_activation_min,
+ output_activation_max, output_dims);
+ // At this point we know that the test succeeded (otherwise it would have
+ // aborted).
+ if (bisect_result == 0) {
+ // The result isn't particularly saturated on one or the other side.
+ // All good, we're done.
+ return;
+ }
+ if (output_activation_bisect_start == output_activation_bisect_end - 1) {
+ // There is still some saturation on one side, but the bisection is
+ // finished anyways. We're done; nothing more we can do about it. This
+ // happens
+ // in particular when using an activation with a narrow range.
+ return;
+ }
+ // Continue the bisection based on the present result.
+ int new_output_activation_bisect_start = bisect_result == 1
+ ? output_shift_bisect_midpoint
+ : output_activation_bisect_start;
+ int new_output_activation_bisect_end = bisect_result == 1
+ ? output_activation_bisect_end
+ : output_shift_bisect_midpoint;
+ TestOneDepthwiseConvBisectOutputShift<Ac>(
+ input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ depth_multiplier, output_offset, output_multiplier,
+ new_output_activation_bisect_start, new_output_activation_bisect_end,
+ output_activation_min, output_activation_max, output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+void TestOneDepthwiseConv(
+ const std::uint8_t* input_data, const Dims<4>& input_dims,
+ std::int32_t input_offset, const std::uint8_t* filter_data,
+ const Dims<4>& filter_dims, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ std::int32_t output_offset, std::int32_t output_multiplier,
+ std::int32_t output_activation_min, std::int32_t output_activation_max,
+ const Dims<4>& output_dims) {
+ TestOneDepthwiseConvBisectOutputShift<Ac>(
+ input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ depth_multiplier, output_offset, output_multiplier, 0, 32,
+ output_activation_min, output_activation_max, output_dims);
+}
+
+void TestOneDepthwiseConv(
+ FusedActivationFunctionType Ac, const std::uint8_t* input_data,
+ const Dims<4>& input_dims, std::int32_t input_offset,
+ const std::uint8_t* filter_data, const Dims<4>& filter_dims,
+ std::int32_t filter_offset, const std::int32_t* bias_data,
+ const Dims<4>& bias_dims, int stride, int pad_width, int pad_height,
+ int depth_multiplier, std::int32_t output_offset,
+ std::int32_t output_multiplier, std::int32_t output_activation_min,
+ std::int32_t output_activation_max, const Dims<4>& output_dims) {
+#define TOCO_HANDLE_CASE(AC_TYPE) \
+ if (AC_TYPE == Ac) { \
+ TestOneDepthwiseConv<AC_TYPE>( \
+ input_data, input_dims, input_offset, filter_data, filter_dims, \
+ filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, \
+ depth_multiplier, output_offset, output_multiplier, \
+ output_activation_min, output_activation_max, output_dims); \
+ return; \
+ }
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1)
+ TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6)
+#undef TOCO_HANDLE_CASE
+}
+
+bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
+ int input_height, int filter_width, int filter_height,
+ int depth_multiplier, int stride,
+ PaddingType padding_type) {
+ const int output_depth = input_depth * depth_multiplier;
+ // The optimized DepthwiseConv implementation currently uses a fixed-size
+ // accumulator buffer on the stack, with that size. This currently means
+ // that it does not support larger output depths. It CHECK's for it,
+ // so it's safe in the sense that if a larger output depth was encountered,
+ // it would explicitly fail. We just need to adjust our testing to that
+ // constraint.
+ const int kMaxSupportedOutputDepth = 1024;
+ if (output_depth > kMaxSupportedOutputDepth) {
+ return false;
+ }
+ const auto ac = RandomElement(std::vector<FusedActivationFunctionType>(
+ {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu,
+ FusedActivationFunctionType::kRelu6,
+ FusedActivationFunctionType::kRelu1}));
+ int output_activation_min = 0;
+ int output_activation_max = 255;
+ if (ac != FusedActivationFunctionType::kNone && UniformRandomInt(0, 1)) {
+ output_activation_min = UniformRandomInt(0, 50);
+ output_activation_max = UniformRandomInt(200, 255);
+ }
+ const std::int32_t output_multiplier =
+ UniformRandomInt(1 << 29, std::numeric_limits<std::int32_t>::max());
+ const std::int32_t input_offset = UniformRandomInt(-256, 0);
+ const std::int32_t filter_offset = UniformRandomInt(-256, 0);
+ const std::int32_t output_offset = UniformRandomInt(-256, 0);
+ Dims<4> input_dims_inference =
+ MakeDimsForInference(input_depth, input_width, input_height, batch);
+ Dims<4> output_dims_inference;
+ int pad_width, pad_height;
+ if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
+ filter_height, stride, padding_type,
+ &output_dims_inference, &pad_width, &pad_height)) {
+ return false;
+ }
+ Dims<4> filter_dims_inference =
+ MakeDimsForInference(output_depth, filter_width, filter_height, 1);
+ Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1);
+ const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
+ const int filter_buffer_size =
+ RequiredBufferSizeForDims(filter_dims_inference);
+ std::vector<std::uint8_t> input_data(input_buffer_size);
+ std::vector<std::uint8_t> filter_data(filter_buffer_size);
+ std::vector<std::int32_t> bias_data(output_depth);
+ FillRandom(&input_data);
+ FillRandom(&filter_data);
+ FillRandom(&bias_data, -10000, 10000);
+ TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
+ input_offset, filter_data.data(), filter_dims_inference,
+ filter_offset, bias_data.data(), bias_dims_inference,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_activation_min,
+ output_activation_max, output_dims_inference);
+ return true;
+}
+
+// This function picks some random DepthwiseConv params, which may or may not
+// be legal. If they're not legal, it returns false. If they're legal,
+// it runs the DepthwiseConv test and returns true. This allows the caller
+// to loop until a test has been run.
+bool TryTestOneDepthwiseConv() {
+ // We have to pick a lot of positive values, where we are particularly
+ // interested in small values because they are most likely to be special
+ // cases in optimized implementations, and secondarily because they allow
+ // tests to run fast, which means we can run more tests and get more
+ // coverage.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10);
+ const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
+ const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const auto padding_type =
+ UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
+
+ return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
+ filter_width, filter_height, depth_multiplier,
+ stride, padding_type);
+}
+
+// Tests parameters for the 3x3 filter kernel.
+bool TryTestOneDepthwiseConv3x3Filter() {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int filter_width = 3;
+ const int filter_height = 3;
+ const int depth_multiplier = 1;
+ const int stride = UniformRandomInt(1, 2);
+ // Although the kernel supports only kValid padding, we test that kSame
+ // is using the correct code path.
+ const auto padding_type =
+ UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
+
+ return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
+ filter_width, filter_height, depth_multiplier,
+ stride, padding_type);
+}
+
+void TestOneDepthwiseConv() {
+ while (!TryTestOneDepthwiseConv()) {
+ }
+}
+
+void TestOneDepthwiseConv3x3Filter() {
+ while (!TryTestOneDepthwiseConv3x3Filter()) {
+ }
+}
+
+TEST(TestDepthwiseConv, TestDepthwiseConv) {
+ const int kTestsToRun = 10 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ TestOneDepthwiseConv();
+ }
+}
+
+TEST(TestDepthwiseConv3x3Filter, TestDepthwiseConv) {
+ const int kTestsToRun = 3 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ TestOneDepthwiseConv3x3Filter();
+ }
+}
+
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 5f9cfc450d..3bbaaa6a9d 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -57,12 +57,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
- // TODO(mirkov): change std::minmax_element with a vectorized call.
- auto minmax_element = std::minmax_element(
- input_ptr_batch, input_ptr_batch + batch_size * input_size);
-
// Save quantization and matmul computation for all zero input.
- if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) {
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
// Quantize input from float to uint8 + quantization params (scaling
// factor).
float unused_min, unused_max;
@@ -83,10 +79,9 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
delete[] scaling_factors;
}
- minmax_element = std::minmax_element(
- hidden_state_ptr_batch, hidden_state_ptr_batch + batch_size * num_units);
// Save quantization and matmul computation for all zero input.
- if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) {
+ if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
+ batch_size * num_units)) {
// Quantize hidden_state
float unused_min, unused_max;
float* scaling_factors = new float[batch_size];
diff --git a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc
new file mode 100644
index 0000000000..7e9ff5242a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc
@@ -0,0 +1,333 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+
+namespace {
+
+class NumberGenerator {
+ public:
+ std::vector<int> RandomIntVector(int n, int min_val, int max_val) {
+ std::vector<int> vec(n);
+ double scale = static_cast<double>(max_val + 1 - min_val) / engine_.max();
+ for (auto& it : vec) {
+ it = min_val + std::floor(engine_() * scale);
+ }
+ return vec;
+ }
+
+ std::mt19937 engine_;
+};
+
+class LogQuantizedTest : public ::testing::Test {
+ public:
+ NumberGenerator generator_;
+};
+
+// input_integer_bits <= 30. output_integer_bits > 0.
+inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits,
+ int output_integer_bits) {
+ const double float_log_sum_of_exps = std::log(
+ static_cast<double>(input_val) * 0.5 / (1 << (30 - input_integer_bits)));
+ static constexpr double min_int =
+ static_cast<double>(std::numeric_limits<int32>::min());
+ static constexpr double max_int =
+ static_cast<double>(std::numeric_limits<int32>::max());
+ double double_result = tflite::TfLiteRound(float_log_sum_of_exps *
+ (1 << (31 - output_integer_bits)));
+ return static_cast<std::int32_t>(
+ std::min(max_int, std::max(min_int, double_result)));
+}
+
+void CheckOutputData(const std::vector<int32>& test_output,
+ const std::vector<int32>& reference_output,
+ const std::vector<int32>& test_input,
+ const string& check_label, int input_integer_bits,
+ int output_integer_bits, int tolerance) {
+ // In the special case of small input, specifically raw value of 5, a rounding
+ // up leads to difference in the output. We do not aim to be accurate for
+ // very small input values, and there should be sufficient input fractional
+ // bits that this is a small input.
+ static constexpr double error_from_rounding_up = 0.0224585;
+ const int n = test_output.size();
+ ASSERT_EQ(n, reference_output.size());
+ for (int i = 0; i < n; ++i) {
+ // Adjust tolerance when input <= 5*2^-(31-input_integer_bits).
+ const int adjusted_tolerance =
+ test_input[i] > 5
+ ? tolerance
+ : std::max(tolerance, static_cast<int>(std::ceil(
+ error_from_rounding_up *
+ (1 << (31 - output_integer_bits)))));
+ ASSERT_LE(std::abs(test_output[i] - reference_output[i]),
+ adjusted_tolerance)
+ << "Failure in \"" << check_label << "\" at i=" << i
+ << ", test_input[i]=" << test_input[i] << "="
+ << static_cast<double>(test_input[i]) / (1 << (31 - input_integer_bits))
+ << ", test_output[i]=" << test_output[i] << "="
+ << static_cast<double>(test_output[i]) /
+ (1 << (31 - output_integer_bits))
+ << ", reference_output[i]=" << reference_output[i] << "="
+ << static_cast<double>(reference_output[i]) /
+ (1 << (31 - output_integer_bits))
+ << ", difference[i]=" << std::abs(reference_output[i] - test_output[i])
+ << "="
+ << static_cast<double>(std::abs(reference_output[i] - test_output[i])) /
+ (1 << (31 - output_integer_bits))
+ << "; tolerance=" << tolerance
+ << ", adj tolerance=" << adjusted_tolerance;
+ }
+}
+
+void RightShiftVector(const std::vector<int32>& shifts,
+ std::vector<int32>* vec) {
+ const int n = vec->size();
+ ASSERT_EQ(n, shifts.size());
+ for (int i = 0; i < n; ++i) {
+ vec->at(i) = std::max(1, vec->at(i) >> shifts[i]);
+ }
+}
+
+template <int OutputIntegerBits, int InputIntegerBits>
+void RunSingleTest(const std::vector<int32>& test_input,
+ const string& check_label, int tolerance) {
+ const int n = test_input.size();
+ std::vector<int32> float_gen_output(n, 0);
+ std::vector<int32> reference_output(n, 0);
+ std::vector<int32> optimized_output(n, 0);
+
+ // Workaround the stupid things that intelligent humans do.
+ // Consequence of __builtin_clz(0u) may equal 31 instead of 32.
+ std::vector<int32> fudged_input(n, 0);
+ for (int i = 0; i < n; ++i) {
+ fudged_input[i] = std::max(test_input[i], 2);
+ }
+
+ for (int i = 0; i < n; ++i) {
+ reference_output[i] =
+ tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl<
+ OutputIntegerBits, InputIntegerBits>(
+ gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
+ fudged_input[i]))
+ .raw();
+ optimized_output[i] =
+ tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl<
+ OutputIntegerBits, InputIntegerBits>(
+ gemmlowp::FixedPoint<int32, InputIntegerBits>::FromRaw(
+ fudged_input[i]))
+ .raw();
+ float_gen_output[i] = LogPositiveValuesViaFloat(
+ fudged_input[i], InputIntegerBits, OutputIntegerBits);
+ }
+ // Note that first check is intolerant.
+ {
+ std::ostringstream label;
+ label << check_label << " / optimized vs reference / InputIntegerBits="
+ << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
+ CheckOutputData(
+ optimized_output, reference_output, test_input, label.str(),
+ InputIntegerBits, OutputIntegerBits, 0);
+ }
+ {
+ std::ostringstream label;
+ label << check_label << " / reference vs float-gen / InputIntegerBits="
+ << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
+ CheckOutputData(
+ reference_output, float_gen_output, test_input, label.str(),
+ InputIntegerBits, OutputIntegerBits, tolerance);
+ }
+ {
+ std::ostringstream label;
+ label << check_label << " optimized vs float-gen / InputIntegerBits="
+ << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits;
+ CheckOutputData(
+ optimized_output, float_gen_output, test_input, label.str(),
+ InputIntegerBits, OutputIntegerBits, tolerance);
+ }
+}
+
+template <int OutputIntegerBits>
+void RunSingleTest(const std::vector<int32>& test_input, int input_integer_bits,
+ const string& check_label, int tolerance) {
+#define INPUT_CASE(K) \
+ case K: \
+ return RunSingleTest<OutputIntegerBits, K>(test_input, check_label, \
+ tolerance)
+ switch (input_integer_bits) {
+ INPUT_CASE(0);
+ INPUT_CASE(1);
+ INPUT_CASE(2);
+ INPUT_CASE(3);
+ INPUT_CASE(4);
+ INPUT_CASE(5);
+ INPUT_CASE(6);
+ INPUT_CASE(7);
+ INPUT_CASE(8);
+ INPUT_CASE(9);
+ INPUT_CASE(10);
+ INPUT_CASE(11);
+ INPUT_CASE(12);
+ INPUT_CASE(13);
+ INPUT_CASE(14);
+ INPUT_CASE(15);
+ INPUT_CASE(16);
+ INPUT_CASE(17);
+ INPUT_CASE(18);
+ INPUT_CASE(19);
+ INPUT_CASE(20);
+ INPUT_CASE(21);
+ INPUT_CASE(22);
+ INPUT_CASE(23);
+ INPUT_CASE(24);
+ INPUT_CASE(25);
+ INPUT_CASE(26);
+ INPUT_CASE(27);
+ INPUT_CASE(28);
+ INPUT_CASE(29);
+ default:
+ ASSERT_LE(input_integer_bits, 30)
+ << "Input integer bits not handled: " << input_integer_bits;
+ }
+#undef INPUT_CASE
+}
+
+void RunSingleTest(const std::vector<int32>& test_input, int input_integer_bits,
+ int output_integer_bits, const string& check_label,
+ int tolerance) {
+#define OUTPUT_CASE(K) \
+ case K: \
+ return RunSingleTest<K>(test_input, input_integer_bits, check_label, \
+ tolerance)
+ switch (output_integer_bits) {
+ OUTPUT_CASE(0);
+ OUTPUT_CASE(1);
+ OUTPUT_CASE(2);
+ OUTPUT_CASE(3);
+ OUTPUT_CASE(4);
+ OUTPUT_CASE(5);
+ OUTPUT_CASE(6);
+ OUTPUT_CASE(7);
+ OUTPUT_CASE(8);
+ OUTPUT_CASE(9);
+ OUTPUT_CASE(10);
+ OUTPUT_CASE(11);
+ OUTPUT_CASE(12);
+ OUTPUT_CASE(13);
+ OUTPUT_CASE(14);
+ OUTPUT_CASE(15);
+ OUTPUT_CASE(16);
+ OUTPUT_CASE(17);
+ OUTPUT_CASE(18);
+ OUTPUT_CASE(19);
+ OUTPUT_CASE(20);
+ OUTPUT_CASE(21);
+ OUTPUT_CASE(22);
+ OUTPUT_CASE(23);
+ OUTPUT_CASE(24);
+ OUTPUT_CASE(25);
+ OUTPUT_CASE(26);
+ OUTPUT_CASE(27);
+ OUTPUT_CASE(28);
+ OUTPUT_CASE(29);
+ default:
+ ASSERT_LE(input_integer_bits, 30)
+ << "Input integer bits not handled: " << input_integer_bits;
+ }
+#undef OUTPUT_CASE
+}
+
+void RunUniformTest(int test_size, int input_integer_bits,
+ int output_integer_bits, const string& check_label,
+ int tolerance, NumberGenerator* generator) {
+ std::vector<int> test_data = generator->RandomIntVector(
+ test_size, 2, std::numeric_limits<int>::max() - 1);
+ test_data[0] = 2;
+ test_data[1] = 3;
+ test_data[2] = 4;
+ test_data[3] = std::numeric_limits<int32>::max() - 2;
+ test_data[4] = std::numeric_limits<int32>::max() - 1;
+ test_data[5] = std::numeric_limits<int32>::max();
+
+ RunSingleTest(test_data, input_integer_bits, output_integer_bits,
+ check_label + " / uniform test", tolerance);
+}
+
+void RunUniformShiftUniformTest(int test_size, int input_integer_bits,
+ int output_integer_bits,
+ const string& check_label, int tolerance,
+ NumberGenerator* generator) {
+ std::vector<int> test_data = generator->RandomIntVector(
+ test_size, 2, std::numeric_limits<int>::max() - 1);
+ std::vector<int> shifts = generator->RandomIntVector(test_size, 0, 29);
+ RightShiftVector(shifts, &test_data);
+
+ RunSingleTest(test_data, input_integer_bits, output_integer_bits,
+ check_label + " / shifted test", tolerance);
+}
+
+TEST_F(LogQuantizedTest, VariedIntegerBits) {
+ static constexpr int kVariations = 250;
+ static constexpr int kRunSize = 250;
+ static constexpr int kIntegerTolerance = 8;
+ static constexpr double kOutputFloatTolerance = 7.0e-7;
+
+ std::vector<int> input_integer_bits =
+ generator_.RandomIntVector(kVariations, 0, 24);
+ std::vector<int> output_integer_bits =
+ generator_.RandomIntVector(kVariations, 1, 10);
+
+ for (int i = 0; i < kVariations; ++i) {
+ int var_output_integer_bits = output_integer_bits[i];
+ int tolerance =
+ std::max(1.0 * kIntegerTolerance,
+ (1 << (31 - var_output_integer_bits)) * kOutputFloatTolerance);
+
+ RunUniformTest(kRunSize, input_integer_bits[i], var_output_integer_bits,
+ "VariedIntegerBits", tolerance, &generator_);
+ RunUniformShiftUniformTest(kRunSize, input_integer_bits[i],
+ var_output_integer_bits, "VariedIntegerBits",
+ tolerance, &generator_);
+ }
+}
+
+TEST_F(LogQuantizedTest, SelectedIntegerBits) {
+ static constexpr int kInputBits = 12;
+ static constexpr int kOutputBits = 5;
+ static constexpr int kRunSize = 100000;
+ static constexpr int kIntegerTolerance = 4;
+
+ RunUniformTest(kRunSize, kInputBits, kOutputBits, "SelectedIntegerBits",
+ kIntegerTolerance, &generator_);
+ RunUniformShiftUniformTest(kRunSize, kInputBits, kOutputBits,
+ "SelectedIntegerBits", kIntegerTolerance,
+ &generator_);
+}
+
+} // namespace
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
new file mode 100644
index 0000000000..b7531ea2e2
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -0,0 +1,241 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+
+namespace tflite {
+namespace {
+
+void RunLogSoftmaxFloatReference(const uint8* input_data,
+ const Dims<4>& dims_common, int32 input_offset,
+ const double input_scale, int stride,
+ float beta, uint8* reference_output_data) {
+ const int ref_buffer_size = RequiredBufferSizeForDims(dims_common);
+ std::vector<float> reference_dequant_data(ref_buffer_size);
+ std::vector<float> reference_output_float_data(ref_buffer_size);
+
+ // Reference data generated via Dequant of input into float, and then applying
+ // float LogSoftmax.
+ reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale,
+ reference_dequant_data.data(), dims_common);
+ optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common,
+ reference_output_float_data.data(), dims_common);
+ // Work with quantized scaling for LogSoftmax, under which 255 represents 0,
+ // and -16 gets nudged up to 0.
+ for (int i = 0; i < ref_buffer_size; i++) {
+ reference_output_data[i] = std::max(
+ 0, static_cast<int>(
+ 255 + std::round(16.0f * reference_output_float_data[i])));
+ }
+}
+
+void CheckOutputData(const uint8* test_output, const uint8* reference_output,
+ const Dims<4>& dims_common, const string& check_label,
+ bool be_exacting) {
+ const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ // While calculating some metrics in floating point, we work with quantized
+ // scaling.
+ std::vector<int> diff(buffer_size);
+ int64_t sum_diff = 0;
+ int64_t sum_abs_diff = 0;
+ for (int i = 0; i < buffer_size; i++) {
+ diff[i] = static_cast<int>(test_output[i]) - reference_output[i];
+ sum_diff += diff[i];
+ sum_abs_diff += std::abs(diff[i]);
+ }
+ // These stats help understand test failures.
+ std::sort(std::begin(diff), std::end(diff));
+ const int min_diff = diff.front();
+ const int max_diff = diff.back();
+ const int median_diff = diff[diff.size() / 2];
+ const float mean_diff = static_cast<float>(sum_diff) / buffer_size;
+ const float mean_abs_diff = static_cast<float>(sum_abs_diff) / buffer_size;
+ // We either check for bit exactness (against the reference quantized version)
+ // or for general accuracy, allowing off-by-one (against the float reference).
+ if (be_exacting) {
+ ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0)
+ << check_label << ": "
+ << "std::abs(min_diff)=" << std::abs(min_diff)
+ << ", std::abs(max_diff)=" << std::abs(max_diff);
+ } else {
+ // For small numbers of samples, the estimates of the means vary more.
+ // Rather than widen the tolerances, we skip the smaller tests.
+ ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) ||
+ buffer_size < 10000) &&
+ std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 &&
+ std::abs(max_diff) <= 1)
+ << check_label << ": "
+ << "buffer_size=" << buffer_size << ", mean_diff=" << mean_diff
+ << ", mean_abs_diff=" << mean_abs_diff
+ << ", median_diff=" << median_diff << ", min_diff=" << min_diff
+ << ", max_diff=" << max_diff;
+ }
+}
+
+// Runs the LogSoftmax and compares against the float reference implementation
+// and the quantized reference implementation.
+void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
+ int32 input_offset, const double input_scale,
+ int stride, float beta) {
+ const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ std::vector<uint8> optimized_logsoftmax_output(buffer_size);
+ std::vector<uint8> reference_float_logsoftmax_output(buffer_size);
+ std::vector<uint8> reference_quant_logsoftmax_output(buffer_size);
+
+ RunLogSoftmaxFloatReference(input_data, dims_common, input_offset,
+ input_scale, stride, beta,
+ reference_float_logsoftmax_output.data());
+
+ int32 input_beta_multiplier;
+ int input_beta_left_shift;
+ int32 reverse_scaling_divisor;
+ int reverse_scaling_right_shift;
+ static const int kScaledDiffIntegerBits = 5;
+ tflite::PreprocessLogSoftmaxScaling(
+ beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier,
+ &input_beta_left_shift, &reverse_scaling_divisor,
+ &reverse_scaling_right_shift);
+ // diff_min has a negative value, and is used to limit the maximum magnitude
+ // of the diffs, which are <= 0.
+ const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+ input_beta_left_shift);
+
+ optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier,
+ input_beta_left_shift, reverse_scaling_divisor,
+ reverse_scaling_right_shift, diff_min,
+ optimized_logsoftmax_output.data(), dims_common);
+ reference_ops::LogSoftmax(
+ input_data, dims_common, input_beta_multiplier, input_beta_left_shift,
+ reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
+ reference_quant_logsoftmax_output.data(), dims_common);
+
+ CheckOutputData(optimized_logsoftmax_output.data(),
+ reference_float_logsoftmax_output.data(), dims_common,
+ "Optimized vs float reference", false);
+ CheckOutputData(optimized_logsoftmax_output.data(),
+ reference_quant_logsoftmax_output.data(), dims_common,
+ "Optimized vs quant reference", true);
+ CheckOutputData(reference_quant_logsoftmax_output.data(),
+ reference_float_logsoftmax_output.data(), dims_common,
+ "Quant reference vs float reference", false);
+}
+
+// This function picks some random LogSoftmax params, which are checked for
+// desirability. If not acceptable, it returns false. If they're OK,
+// it runs the LogSoftmax test and returns true. This allows the caller
+// to loop until a test has been run.
+//
+// Currently we do not reject for any reason.
+bool TryOneUniformLogSoftmax() {
+ // We pick mostly positive values, on the whole emphasizing smaller values and
+ // therefore faster tests. We test a wider range of depths. In the case of
+ // LogSoftmax, the width and height really just create test repetitions.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500);
+ const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0));
+ const int32 input_offset = UniformRandomInt(-256, 0);
+ static constexpr float beta = 1.0f;
+
+ Dims<4> dims_common =
+ MakeDimsForInference(input_depth, input_width, input_height, batch);
+ const int buffer_size = RequiredBufferSizeForDims(dims_common);
+
+ std::vector<uint8> input_data(buffer_size);
+ FillRandom(&input_data);
+ RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset,
+ input_scale, stride, beta);
+ return true;
+}
+
+// See TryOneUniformLogSoftmax() for a general description.
+//
+// Tests with "skyscraper" input patterns are included for two reasons. (a)
+// Bimodal distributions are potentially challenging and perhaps more
+// realistic than simple uniform random inputs. (b) Some implementations of
+// LogSoftmax may adapt as they traverse the depth, and so we test handling of
+// cases where relatively small values are encountered at the beginning and end.
+bool TryOneSkyscraperLogSoftmax(bool small_depth) {
+ // We pick mostly positive values, on the whole emphasizing smaller values and
+ // therefore faster tests. We test a wider range of depths. In the case of
+ // LogSoftmax, the width and height really just create test repetitions.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = small_depth
+ ? ExponentialRandomPositiveInt(0.75f, 40, 500)
+ : ExponentialRandomPositiveInt(0.75f, 175, 500);
+ const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0));
+ const int32 input_offset = UniformRandomInt(-256, 0);
+ static constexpr float beta = 1.0f;
+ // Extra parameters for skyscraper input patterns.
+ const double middle_proportion =
+ ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0);
+ const int middle_min = UniformRandomInt(0, 255);
+ const int sides_max = UniformRandomInt(0, middle_min);
+
+ Dims<4> dims_common =
+ MakeDimsForInference(input_depth, input_width, input_height, batch);
+ const int buffer_size = RequiredBufferSizeForDims(dims_common);
+
+ std::vector<uint8> input_data(buffer_size);
+ FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
+ sides_max);
+ RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset,
+ input_scale, stride, beta);
+ return true;
+}
+
+TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneUniformLogSoftmax()) {
+ }
+ }
+}
+
+TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneSkyscraperLogSoftmax(false)) {
+ }
+ }
+}
+
+TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneSkyscraperLogSoftmax(true)) {
+ }
+ }
+}
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index dd6932ffe7..3fd00c8930 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1691,14 +1691,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
const int filter_width = ArraySize(filter_dims, 1);
const int output_height = ArraySize(output_dims, 2);
const int output_width = ArraySize(output_dims, 1);
+#ifdef USE_NEON
+ const bool shift_left = (output_shift <= 0);
+ const int32 multiplier_power_of_two = shift_left ? (1 << -output_shift) : 1;
+#endif
TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
-#ifdef __aarch64__
+// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
+// Jetson TX-2. This compiler does not support the offsetof() macro.
+#if defined(__aarch64__) && !defined(GOOGLE_L4T)
// Call kernel optimized for depthwise convolutions using 3x3 filters if
// parameters are supported.
- if (Fast3x3FilterKernelSupported(input_dims, filter_dims, stride_width,
- stride_height, pad_width, pad_height,
- depth_multiplier, output_dims)) {
+ if (Fast3x3FilterKernelSupported(
+ input_dims, filter_dims, stride_width, stride_height, pad_width,
+ pad_height, depth_multiplier, output_dims, output_shift)) {
DepthwiseConv3x3Filter(input_data, input_dims, input_offset, filter_data,
filter_dims, filter_offset, bias_data, bias_dims,
stride_width, stride_height, pad_width, pad_height,
@@ -1833,12 +1839,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
acc[j] = vld1q_s32(acc_buffer + i + 4 * j);
}
- // Fixed-point multiplication.
- for (int j = 0; j < 4; j++) {
- acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
- }
- for (int j = 0; j < 4; j++) {
- acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+ if (!shift_left) {
+ // Fixed-point multiplication.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
+ }
+ for (int j = 0; j < 4; j++) {
+ acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+ }
+ } else {
+ // Fixed-point multiplication.
+ for (int j = 0; j < 4; j++) {
+ acc[j] = vmulq_n_s32(acc[j], multiplier_power_of_two);
+ acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
+ }
}
// Add the output offset.
for (int j = 0; j < 4; j++) {
@@ -1870,12 +1884,21 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
for (; i <= num_output_values - 8; i += 8) {
int32x4_t acc0 = vld1q_s32(acc_buffer + i);
int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4);
- // Fixed-point multiplication.
- acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
- acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
- // Rounding right shift.
- acc0 = RoundingDivideByPOT(acc0, output_shift);
- acc1 = RoundingDivideByPOT(acc1, output_shift);
+ if (!shift_left) {
+ // Fixed-point multiplication.
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
+ acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
+ // Rounding right shift.
+ acc0 = RoundingDivideByPOT(acc0, output_shift);
+ acc1 = RoundingDivideByPOT(acc1, output_shift);
+ } else {
+ // Fixed-point multiplication.
+ acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
+ acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
+
+ acc1 = vmulq_n_s32(acc1, multiplier_power_of_two);
+ acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
+ }
// Add the output offset.
acc0 = vaddq_s32(acc0, output_offset_vec);
acc1 = vaddq_s32(acc1, output_offset_vec);
@@ -1899,10 +1922,16 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// that will have to go through the very slow scalar code.
for (; i <= num_output_values - 4; i += 4) {
int32x4_t acc = vld1q_s32(acc_buffer + i);
- // Fixed-point multiplication.
- acc = vqrdmulhq_n_s32(acc, output_multiplier);
- // Rounding right shift.
- acc = RoundingDivideByPOT(acc, output_shift);
+ if (!shift_left) {
+ // Fixed-point multiplication.
+ acc = vqrdmulhq_n_s32(acc, output_multiplier);
+ // Rounding right shift.
+ acc = RoundingDivideByPOT(acc, output_shift);
+ } else {
+ // Fixed-point multiplication.
+ acc = vmulq_n_s32(acc, multiplier_power_of_two);
+ acc = vqrdmulhq_n_s32(acc, output_multiplier);
+ }
// Add the output offset.
acc = vaddq_s32(acc, output_offset_vec);
// Apply the activation function.
@@ -1923,8 +1952,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Handle leftover values, one by one. This is very slow.
for (; i < num_output_values; i++) {
int32 acc = acc_buffer[i];
- acc = MultiplyByQuantizedMultiplierSmallerThanOne(
- acc, output_multiplier, output_shift);
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ -output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 55e0d5c3aa..8cd72239e9 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -23,3848 +23,2191 @@ limitations under the License.
namespace tflite {
namespace optimized_ops {
-#ifdef __aarch64__
-
-inline void preload_l1_keep(const uint8* ptr) {
-#ifdef GEMMLOWP_ARM_64
- asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
-#else
- gemmlowp::Prefetch(ptr);
-#endif
-}
-
-// Implementation of quantized DepthwiseConv for 3x3 filters.
-
-// Below are helper structs to remove the use of arrays.
-// There is an llvm bug that causes significant slowdown when using arrays for
-// NEON intrinsics vector data types.
-// See: https://bugs.llvm.org/show_bug.cgi?id=34945
-
-struct Int32x8 {
- int32x4_t low, high;
-};
-
-struct Filter3x3x8 {
- int16x8_t f0, f1, f2, f3, f4, f5, f6, f7, f8;
-};
-
-// Loads 3x3 filter of depth 8 and adds filter offsets.
-inline Filter3x3x8 Load3x3Filter(const uint8* filter_ptr, int32 filter_offset,
- int output_depth) {
- Filter3x3x8 filter;
-
- uint8x8_t temp_u8_0, temp_u8_1, temp_u8_2, temp_u8_3, temp_u8_4, temp_u8_5,
- temp_u8_6, temp_u8_7, temp_u8_8;
- int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
-
- temp_u8_0 = vld1_u8(filter_ptr + 0 * output_depth);
- temp_u8_1 = vld1_u8(filter_ptr + 1 * output_depth);
- temp_u8_2 = vld1_u8(filter_ptr + 2 * output_depth);
- temp_u8_3 = vld1_u8(filter_ptr + 3 * output_depth);
- temp_u8_4 = vld1_u8(filter_ptr + 4 * output_depth);
- temp_u8_5 = vld1_u8(filter_ptr + 5 * output_depth);
- temp_u8_6 = vld1_u8(filter_ptr + 6 * output_depth);
- temp_u8_7 = vld1_u8(filter_ptr + 7 * output_depth);
- temp_u8_8 = vld1_u8(filter_ptr + 8 * output_depth);
-
- filter.f0 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_0));
- filter.f1 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_1));
- filter.f2 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_2));
- filter.f3 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_3));
- filter.f4 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_4));
- filter.f5 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_5));
- filter.f6 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_6));
- filter.f7 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_7));
- filter.f8 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_8));
-
- filter.f0 = vaddq_s16(filter.f0, filter_offset_vec);
- filter.f1 = vaddq_s16(filter.f1, filter_offset_vec);
- filter.f2 = vaddq_s16(filter.f2, filter_offset_vec);
- filter.f3 = vaddq_s16(filter.f3, filter_offset_vec);
- filter.f4 = vaddq_s16(filter.f4, filter_offset_vec);
- filter.f5 = vaddq_s16(filter.f5, filter_offset_vec);
- filter.f6 = vaddq_s16(filter.f6, filter_offset_vec);
- filter.f7 = vaddq_s16(filter.f7, filter_offset_vec);
- filter.f8 = vaddq_s16(filter.f8, filter_offset_vec);
-
- return filter;
-}
-
-// Applies activation, offset and downquantize on a set of accumulator
-// registers that correspond to a 2x2 output of depth 8.
-// Stores results to output.
-inline void DownquantizeAndStore2x2Output(
- Int32x8 acc_0, Int32x8 acc_1, Int32x8 acc_2, Int32x8 acc_3,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- using gemmlowp::RoundingDivideByPOT;
- const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
- const int32x4_t output_activation_min_vec =
- vdupq_n_s32(output_activation_min);
- const int32x4_t output_activation_max_vec =
- vdupq_n_s32(output_activation_max);
-
- // Fixed-point multiplication.
- acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier);
- acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier);
- acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier);
- acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier);
- acc_2.low = vqrdmulhq_n_s32(acc_2.low, output_multiplier);
- acc_2.high = vqrdmulhq_n_s32(acc_2.high, output_multiplier);
- acc_3.low = vqrdmulhq_n_s32(acc_3.low, output_multiplier);
- acc_3.high = vqrdmulhq_n_s32(acc_3.high, output_multiplier);
-
- acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift);
- acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift);
- acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift);
- acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift);
- acc_2.low = RoundingDivideByPOT(acc_2.low, output_shift);
- acc_2.high = RoundingDivideByPOT(acc_2.high, output_shift);
- acc_3.low = RoundingDivideByPOT(acc_3.low, output_shift);
- acc_3.high = RoundingDivideByPOT(acc_3.high, output_shift);
-
- // Add the output offset.
- acc_0.low = vaddq_s32(acc_0.low, output_offset_vec);
- acc_0.high = vaddq_s32(acc_0.high, output_offset_vec);
- acc_1.low = vaddq_s32(acc_1.low, output_offset_vec);
- acc_1.high = vaddq_s32(acc_1.high, output_offset_vec);
- acc_2.low = vaddq_s32(acc_2.low, output_offset_vec);
- acc_2.high = vaddq_s32(acc_2.high, output_offset_vec);
- acc_3.low = vaddq_s32(acc_3.low, output_offset_vec);
- acc_3.high = vaddq_s32(acc_3.high, output_offset_vec);
-
- // Apply the activation function.
- acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec);
- acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec);
- acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec);
- acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec);
- acc_2.low = vmaxq_s32(acc_2.low, output_activation_min_vec);
- acc_2.high = vmaxq_s32(acc_2.high, output_activation_min_vec);
- acc_3.low = vmaxq_s32(acc_3.low, output_activation_min_vec);
- acc_3.high = vmaxq_s32(acc_3.high, output_activation_min_vec);
-
- acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec);
- acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec);
- acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec);
- acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec);
- acc_2.low = vminq_s32(acc_2.low, output_activation_max_vec);
- acc_2.high = vminq_s32(acc_2.high, output_activation_max_vec);
- acc_3.low = vminq_s32(acc_3.low, output_activation_max_vec);
- acc_3.high = vminq_s32(acc_3.high, output_activation_max_vec);
-
- // Saturating cast to uint8 and store to destination.
- int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low);
- int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high);
- int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low);
- int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high);
- int16x4_t acc_2_low_s16 = vqmovn_s32(acc_2.low);
- int16x4_t acc_2_high_s16 = vqmovn_s32(acc_2.high);
- int16x4_t acc_3_low_s16 = vqmovn_s32(acc_3.low);
- int16x4_t acc_3_high_s16 = vqmovn_s32(acc_3.high);
-
- int16x8_t res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16);
- int16x8_t res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16);
- int16x8_t res_2_s16 = vcombine_s16(acc_2_low_s16, acc_2_high_s16);
- int16x8_t res_3_s16 = vcombine_s16(acc_3_low_s16, acc_3_high_s16);
-
- uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16);
- uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16);
- uint8x8_t res_2_u8 = vqmovun_s16(res_2_s16);
- uint8x8_t res_3_u8 = vqmovun_s16(res_3_s16);
-
- vst1_u8(output_ptr, res_0_u8);
- vst1_u8(output_ptr + output_depth, res_1_u8);
- vst1_u8(output_ptr + output_depth * output_width, res_2_u8);
- vst1_u8(output_ptr + output_depth * output_width + output_depth, res_3_u8);
-}
-
-inline void DownquantizeAndStore(Int32x8 acc, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max,
- uint8* output_ptr) {
- using gemmlowp::RoundingDivideByPOT;
- const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
- const int32x4_t output_activation_min_vec =
- vdupq_n_s32(output_activation_min);
- const int32x4_t output_activation_max_vec =
- vdupq_n_s32(output_activation_max);
-
- acc.low = vqrdmulhq_n_s32(acc.low, output_multiplier);
- acc.high = vqrdmulhq_n_s32(acc.high, output_multiplier);
-
- acc.low = RoundingDivideByPOT(acc.low, output_shift);
- acc.high = RoundingDivideByPOT(acc.high, output_shift);
-
- acc.low = vaddq_s32(acc.low, output_offset_vec);
- acc.high = vaddq_s32(acc.high, output_offset_vec);
-
- acc.low = vmaxq_s32(acc.low, output_activation_min_vec);
- acc.high = vmaxq_s32(acc.high, output_activation_min_vec);
-
- acc.low = vminq_s32(acc.low, output_activation_max_vec);
- acc.high = vminq_s32(acc.high, output_activation_max_vec);
-
- int16x4_t acc_low_s16 = vqmovn_s32(acc.low);
- int16x4_t acc_high_s16 = vqmovn_s32(acc.high);
-
- int16x8_t res_s16 = vcombine_s16(acc_low_s16, acc_high_s16);
- uint8x8_t res_u8 = vqmovun_s16(res_s16);
- vst1_u8(output_ptr, res_u8);
-}
-
-inline void DownquantizeAndStore2Output(
- Int32x8 acc_0, Int32x8 acc_1, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_ptr, int output_ptr_offset) {
- {
- using gemmlowp::RoundingDivideByPOT;
- const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
- const int32x4_t output_activation_min_vec =
- vdupq_n_s32(output_activation_min);
- const int32x4_t output_activation_max_vec =
- vdupq_n_s32(output_activation_max);
-
- // Fixed-point multiplication.
- acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier);
- acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier);
- acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier);
- acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier);
-
- acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift);
- acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift);
- acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift);
- acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift);
-
- // Add the output offset.
- acc_0.low = vaddq_s32(acc_0.low, output_offset_vec);
- acc_0.high = vaddq_s32(acc_0.high, output_offset_vec);
- acc_1.low = vaddq_s32(acc_1.low, output_offset_vec);
- acc_1.high = vaddq_s32(acc_1.high, output_offset_vec);
-
- // Apply the activation function.
- acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec);
- acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec);
- acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec);
- acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec);
-
- acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec);
- acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec);
- acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec);
- acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec);
- }
-
- // Saturating cast to uint8 and store to destination.
- int16x8_t res_0_s16;
- {
- int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low);
- int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high);
- res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16);
- }
-
- int16x8_t res_1_s16;
- {
- int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low);
- int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high);
- res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16);
- }
-
- uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16);
- uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16);
- vst1_u8(output_ptr, res_0_u8);
- vst1_u8(output_ptr + output_ptr_offset, res_1_u8);
-}
-
-// Performs multiply accumulate on 3 inputs of depth 8.
-inline Int32x8 MultiplyAccumulateRow(Int32x8 accum, int16x8_t f0, int16x8_t f1,
- int16x8_t f2, int16x8_t i0, int16x8_t i1,
- int16x8_t i2) {
- accum.low = vmlal_s16(accum.low, vget_low_s16(f0), vget_low_s16(i0));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f0), vget_high_s16(i0));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f1), vget_low_s16(i1));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f1), vget_high_s16(i1));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f2), vget_low_s16(i2));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f2), vget_high_s16(i2));
- return accum;
-}
-
-// Performs multiply accumulate on 3 inputs of depth 8.
-inline Int32x8 MultiplyAccumulate3x3Filter(const Filter3x3x8& f, int16x8_t i0,
- int16x8_t i1, int16x8_t i2,
- int16x8_t i3, int16x8_t i4,
- int16x8_t i5, int16x8_t i6,
- int16x8_t i7, int16x8_t i8,
- Int32x8 accum) {
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f0), vget_low_s16(i0));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f0), vget_high_s16(i0));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f1), vget_low_s16(i1));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f1), vget_high_s16(i1));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f2), vget_low_s16(i2));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f2), vget_high_s16(i2));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f3), vget_low_s16(i3));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f3), vget_high_s16(i3));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f4), vget_low_s16(i4));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f4), vget_high_s16(i4));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f5), vget_low_s16(i5));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f5), vget_high_s16(i5));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f6), vget_low_s16(i6));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f6), vget_high_s16(i6));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f7), vget_low_s16(i7));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f7), vget_high_s16(i7));
- accum.low = vmlal_s16(accum.low, vget_low_s16(f.f8), vget_low_s16(i8));
- accum.high = vmlal_s16(accum.high, vget_high_s16(f.f8), vget_high_s16(i8));
- return accum;
-}
-
-inline void DotProductAndStore(const Filter3x3x8& filter, int16x8_t i0,
- int16x8_t i1, int16x8_t i2, int16x8_t i3,
- int16x8_t i4, int16x8_t i5, int16x8_t i6,
- int16x8_t i7, int16x8_t i8,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr) {
- Int32x8 acc;
- acc.low = vld1q_s32(bias_ptr);
- acc.high = vld1q_s32(bias_ptr + 4);
-
- acc = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7, i8,
- acc);
-
- DownquantizeAndStore(acc, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max,
- output_ptr);
-}
-
-// Performs multiply-accumulate on a 3x4 input for 2 horizontal outputs.
-inline void DotProductAndStore2xStride1(
- const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2,
- int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7,
- int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11,
- const int32* bias_ptr, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_ptr, int output_ptr_offset) {
- Int32x8 acc_0, acc_1;
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i4, i5, i6, i8, i9,
- i10, acc_0);
- acc_1 = MultiplyAccumulate3x3Filter(filter, i1, i2, i3, i5, i6, i7, i9, i10,
- i11, acc_1);
- DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier,
- output_shift, output_activation_min,
- output_activation_max, output_ptr,
- output_ptr_offset);
-}
-
-// Performs multiply-accumulate on a 4x3 input for 2 vertical outputs.
-inline void DotProductAndStore2yStride1(
- const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2,
- int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7,
- int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11,
- const int32* bias_ptr, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- uint8* output_ptr, int output_ptr_offset) {
- Int32x8 acc_0, acc_1;
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7,
- i8, acc_0);
- acc_1 = MultiplyAccumulate3x3Filter(filter, i3, i4, i5, i6, i7, i8, i9, i10,
- i11, acc_1);
- DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier,
- output_shift, output_activation_min,
- output_activation_max, output_ptr,
- output_ptr_offset);
-}
-
-// A kernel that is optimized on the number of output cells in the x and y
-// direction, and the stride. Assumes 3x3 filters of 8 depth.
-template <int kFixedOutputY, int kFixedOutputX, int kFixedStrideWidth,
- int kFixedStrideHeight>
-struct ConvKernel3x3FilterDepth8 {};
-
-template <>
-struct ConvKernel3x3FilterDepth8<8, 8, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- // To process 8x8 outputs using a 3x3 filter, we require 10x10 inputs.
- // Load inputs for the first 2 filters on the top left, then slide to
- // the right, down, left, down, right, etc. in a snake-like path. This
- // minimizes the total number of loads.
- //
- // INPUT OUTPUT
- // |\----------------\ |\------------\
- // | \ \ | \ \
- // | \----------------\ | \------------\
- // | | 0 ... 9 | | | 0 ... 7 |
- // | | 10 ... 19 | ---> | | 8 ... 15 |
- // | | 20 ... 29 | \ | .. ... .. |
- // \ | .. ... .. | \| 56 ... 63 |
- // \| 90 ... 109 | |------------|
- // |----------------|
- //
- // The first set of loads corresponds to:
- //
- // INPUT OUTPUT
- // |\----------------- |\-----------
- // | \ | \
- // | \----------------- | \----------
- // | | 0 1 2 3 ... | | 0 1 ...
- // | | 10 11 12 13 ... ---> | | .. ...
- // | | 20 21 22 23 ... | .. ...
- // | | .. ... ...
- //
- // The next set of loads correspond to a sliding window to the right.
- // It loads inputs 4, 5, 14, 15, 23, 24 and keeps 2, 3, 12, 13, and 22:
- //
- // INPUT OUTPUT
- // |\------------------- |\-------------
- // | \ | \
- // | \------------------- | \------------
- // | | .. 2 3 4 5 ... | | .. 2 3 ...
- // | | .. 12 13 14 15 ... ---> | | .. ...
- // | | .. 21 22 23 24 ... | .. ...
- // | | .. ... ...
- //
- // And so on...
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top left. Referring to the
- // indexes in the diagram above, this corresponds to outputs (0) and (1).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- // Slide to the right for outputs x = [2, 3], y = 0. Referring to the
- // indexes in the diagram above, this corresponds to outputs (2) and (3).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth, output_depth);
-
- // Slide to the right again for outputs x = [4, 5], y = 0. Referring to the
- // indexes in the diagram above, this corresponds to outputs (4) and (5).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 6 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 4 * output_depth, output_depth);
-
- // Slide to the right one last time for outputs x = [6, 7], y = 0.
- // Referring to the indexes in the diagram above, this corresponds to
- // outputs (6) and (7).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 8 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 6 * output_depth, output_depth);
-
- // Slide to down for outputs x = [6, 7], y = 1. Referring to the indexes in
- // the diagram above, this corresponds to outputs (14) and (15).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 6 * output_depth + output_row_size,
- output_depth);
-
- // Slide left for outputs x = [4, 5], y = 1. Referring to the indexes in
- // the diagram above, this corresponds to outputs (12) and (13).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 4 * output_depth + output_row_size,
- output_depth);
-
- // Slide left again for outputs x = [2, 3], y = 1. Referring to the indexes
- // in the diagram above, this corresponds to outputs (10) and (11).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 2 * input_depth + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth + output_row_size,
- output_depth);
-
- // Slide left one more time for outputs x = [0, 1], y = 1. Referring to the
- // indexes in the diagram above, this corresponds to outputs (8) and (9).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + output_row_size, output_depth);
-
- // Slide down for outputs x = [0, 1], y = 2. Referring to the
- // indexes in the diagram above, this corresponds to outputs (16) and (17).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_row_size, output_depth);
-
- // Slide right for outputs x = [2, 3], y = 2. Referring to the
- // indexes in the diagram above, this corresponds to outputs (18) and (19).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 2 * output_row_size, output_depth);
-
- // Slide right for outputs x = [4, 5], y = 2. Referring to the
- // indexes in the diagram above, this corresponds to outputs (20) and (21).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 2 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 2 * output_row_size, output_depth);
-
- // Slide right one more time for outputs x = [6, 7], y = 2. Referring to the
- // indexes in the diagram above, this corresponds to outputs (22) and (23).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 8 * input_depth + 2 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 2 * output_row_size, output_depth);
-
- // Slide down for outputs x = [6, 7], y = 3. Referring to the indexes in
- // the diagram above, this corresponds to outputs (30) and (31).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 3 * output_row_size, output_depth);
-
- // Slide left for outputs x = [4, 5], y = 3. Referring to the indexes in
- // the diagram above, this corresponds to outputs (28) and (29).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 3 * output_row_size, output_depth);
-
- // Slide left for outputs x = [2, 3], y = 3. Referring to the indexes in
- // the diagram above, this corresponds to outputs (26) and (27).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 3 * output_row_size, output_depth);
-
- // Slide left one more time for outputs x = [0, 1], y = 3. Referring to the
- // indexes in the diagram above, this corresponds to outputs (24) and (25).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 3 * output_row_size, output_depth);
-
- // Slide down for outputs x = [0, 1], y = 4. Referring to the indexes in
- // the diagram above, this corresponds to outputs (32) and (33).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 4 * output_row_size, output_depth);
-
- // Slide right for outputs x = [2, 3], y = 4. Referring to the indexes in
- // the diagram above, this corresponds to outputs (34) and (35).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 4 * output_row_size, output_depth);
-
- // Slide right for outputs x = [4, 5], y = 4. Referring to the indexes in
- // the diagram above, this corresponds to outputs (36) and (37).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 4 * output_row_size, output_depth);
-
- // Slide right one more time for outputs x = [6, 7], y = 4. Referring to the
- // indexes in the diagram above, this corresponds to outputs (38) and (39).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 8 * input_depth + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 4 * output_row_size, output_depth);
-
- // Slide down for outputs x = [6, 7], y = 5. Referring to the indexes in
- // the diagram above, this corresponds to outputs (46) and (47).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 7 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 5 * output_row_size, output_depth);
-
- // Slide left for outputs x = [4, 5], y = 5. Referring to the indexes in
- // the diagram above, this corresponds to outputs (44) and (45).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 5 * output_row_size, output_depth);
-
- // Slide left for outputs x = [2, 3], y = 5. Referring to the indexes in
- // the diagram above, this corresponds to outputs (42) and (43).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 5 * output_row_size, output_depth);
-
- // Slide left one more time for outputs x = [0, 1], y = 5. Referring to the
- // indexes in the diagram above, this corresponds to outputs (40) and (41).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 5 * output_row_size, output_depth);
-
- // Slide down for outputs x = [0, 1], y = 6. Referring to the indexes in
- // the diagram above, this corresponds to outputs (48) and (49).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 8 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 6 * output_row_size, output_depth);
-
- // Slide right for outputs x = [2, 3], y = 6. Referring to the indexes in
- // the diagram above, this corresponds to outputs (50) and (51).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 6 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 6 * output_row_size, output_depth);
-
- // Slide right for outputs x = [4, 5], y = 6. Referring to the indexes in
- // the diagram above, this corresponds to outputs (52) and (53).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 6 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 6 * output_row_size, output_depth);
-
- // Slide right one more time for outputs x = [6, 7], y = 6. Referring to the
- // indexes in the diagram above, this corresponds to outputs (54) and (55).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 8 * input_depth + 6 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 6 * output_row_size, output_depth);
-
- // Slide down for outputs x = [6, 7], y = 7. Referring to the indexes in the
- // diagram above, this corresponds to outputs (62) and (63).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 6 * input_depth + 9 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 6 * output_depth + 7 * output_row_size, output_depth);
-
- // Slide left for outputs x = [4, 5], y = 7. Referring to the indexes in the
- // diagram above, this corresponds to outputs (60) and (61).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 7 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 4 * output_depth + 7 * output_row_size, output_depth);
-
- // Slide left for outputs x = [2, 3], y = 7. Referring to the indexes in the
- // diagram above, this corresponds to outputs (58) and (59).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 7 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 7 * output_row_size, output_depth);
-
- // Slide left one more time for outputs x = [0, 1], y = 7. Referring to the
- // indexes in the diagram above, this corresponds to outputs (56) and (57).
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 7 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 7 * output_row_size, output_depth);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 4, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- // To process 4x4 outputs using a 3x3 filter, we require 6x6 inputs.
- // Load inputs for the first 2 filters on the top left, then slide to
- // the right, down, left, down, right, etc. in a snake-like path. This
- // minimizes the total number of loads.
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- // Now load 1x2 inputs on the top right.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth, output_depth);
-
- // Now load next inputs when sliding window down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth + output_row_size,
- output_depth);
-
- // Now load next inputs when sliding window left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + output_row_size, output_depth);
-
- // Now load next inputs when sliding window down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_row_size, output_depth);
-
- // Now load next inputs when sliding window right.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0,
- input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 2 * output_row_size, output_depth);
-
- // Now load next inputs when sliding window down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max,
- output_ptr + 2 * output_depth + 3 * output_row_size, output_depth);
-
- // Now load next inputs when sliding window left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 3 * output_row_size, output_depth);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 2, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Now load next inputs one row down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Now load next row.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2,
- input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Now load last row.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 5 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 1, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 2x1 outputs starting from the top.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2yStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_row_size);
-
- // Load inputs for bottom 2 rows.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- }
-
- DotProductAndStore2yStride1(
- filter, input_6, input_7, input_8, input_9, input_10, input_11, input_0,
- input_1, input_2, input_3, input_4, input_5, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_row_size,
- output_row_size);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<2, 2, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- Int32x8 acc_0, acc_1, acc_2, acc_3;
-
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_2.low = vld1q_s32(bias_ptr);
- acc_3.low = vld1q_s32(bias_ptr);
-
- bias_ptr += 4;
- acc_0.high = vld1q_s32(bias_ptr);
- acc_1.high = vld1q_s32(bias_ptr);
- acc_2.high = vld1q_s32(bias_ptr);
- acc_3.high = vld1q_s32(bias_ptr);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
-
- // Add scope for input registers to help the compiler know that it is
- // not needed.
- {
- // To process 2x2 outputs using a 3x3 filter, we require 4x4 inputs.
- // Load inputs for the top two filters first.
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- const uint8* ptr = input_ptr;
-
- // Load top 3 rows.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- // Multiply-accum for top-left output.
- acc_0 = MultiplyAccumulate3x3Filter(filter, input_0, input_1, input_2,
- input_4, input_5, input_6, input_8,
- input_9, input_10, acc_0);
-
- // Multiply-accum for top-right output.
- acc_1 = MultiplyAccumulate3x3Filter(filter, input_1, input_2, input_3,
- input_5, input_6, input_7, input_9,
- input_10, input_11, acc_1);
-
- // Now load the bottom row.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- }
-
- // Multiply-accum for bottom-left output.
- acc_2 = MultiplyAccumulate3x3Filter(filter, input_4, input_5, input_6,
- input_8, input_9, input_10, input_0,
- input_1, input_2, acc_2);
-
- // Multiply-accum for bottom-right output.
- acc_3 = MultiplyAccumulate3x3Filter(filter, input_5, input_6, input_7,
- input_9, input_10, input_11, input_1,
- input_2, input_3, acc_3);
- }
-
- DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset,
- output_multiplier, output_shift,
- output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<2, 4, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- const int output_row_size = output_depth * output_width;
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- // Now load 1x2 inputs on the top right.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + 4 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth, output_depth);
-
- // Now load next inputs when sliding window down.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8,
- input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth + output_row_size,
- output_depth);
-
- // Now load next inputs when sliding window left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10,
- input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + output_row_size, output_depth);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<1, 4, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the left.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
-
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
-
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth);
-
- // Now load 1x2 inputs on the right.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr + input_depth * 4;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_2 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
-
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- DotProductAndStore2xStride1(
- filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4,
- input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr + 2 * output_depth, output_depth);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<2, 1, 1, 1> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- // To process 2x1 outputs using a 3x3 filter, we require 4x3 inputs.
- // Load all inputs at the beginning.
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11;
-
- // Load inputs for 1x2 outputs starting from the top left.
- {
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5;
-
- const uint8* ptr = input_ptr;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- input_10 = vaddq_s16(input_10, input_offset_vec);
- input_11 = vaddq_s16(input_11, input_offset_vec);
- }
-
- DotProductAndStore2yStride1(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth * output_width);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 2, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- const int output_row_size = output_depth * output_width;
-
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- Int32x8 acc_0, acc_1;
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9;
-
- const uint8* ptr = input_ptr;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4;
-
- // Load first 2 rows.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load next 2 rows.
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- DownquantizeAndStore2Output(
- acc_0, acc_1, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Moving onto the next row of outputs.
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load next 2 rows.
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- DownquantizeAndStore2Output(
- acc_0, acc_1, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Moving onto the next row of outputs.
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load next 2 rows.
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- DownquantizeAndStore2Output(
- acc_0, acc_1, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth);
-
- output_ptr += output_row_size;
-
- // Moving onto the next row of outputs.
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_0.high = vld1q_s32(bias_ptr + 4);
- acc_1.high = vld1q_s32(bias_ptr + 4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load last row.
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- DownquantizeAndStore2Output(
- acc_0, acc_1, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 4, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- // Reuse 4x2 kernel twice.
- ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth,
- output_width);
-
- ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run(
- input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr + 2 * output_depth, output_depth, output_width);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<4, 1, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- const int output_row_size = output_depth * output_width;
-
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7,
- temp_8;
-
- const uint8* ptr = input_ptr;
+// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
+// Jetson TX-2. This compiler does not support the offsetof() macro.
+#if defined(__aarch64__) && !defined(GOOGLE_L4T)
- // Load all inputs for top output.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Second output.
- output_ptr += output_row_size;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
-
- DotProductAndStore(
- filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3,
- input_4, input_5, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Third output.
- output_ptr += output_row_size;
-
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
-
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
-
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
-
- DotProductAndStore(
- filter, input_3, input_4, input_5, input_6, input_7, input_8, input_0,
- input_1, input_2, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Fourth output.
- output_ptr += output_row_size;
-
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
- }
-};
-
-template <>
-struct ConvKernel3x3FilterDepth8<2, 2, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- Int32x8 acc_0, acc_1, acc_2, acc_3;
- acc_0.low = vld1q_s32(bias_ptr);
- acc_1.low = vld1q_s32(bias_ptr);
- acc_2.low = vld1q_s32(bias_ptr);
- acc_3.low = vld1q_s32(bias_ptr);
-
- bias_ptr += 4;
- acc_0.high = vld1q_s32(bias_ptr);
- acc_1.high = vld1q_s32(bias_ptr);
- acc_2.high = vld1q_s32(bias_ptr);
- acc_3.high = vld1q_s32(bias_ptr);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
-
- // Add scope for input registers to help the compiler know that it is
- // not needed.
- {
- // To process 2x2 outputs using a 3x3 filter at stride 2, we require
- // 5x5 inputs. We load the first 5x2 inputs at a time.
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, input_9;
-
- const uint8* ptr = input_ptr;
-
- // Load inputs.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4;
-
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load next inputs.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_9 = vaddq_s16(input_9, input_offset_vec);
- }
-
- acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
-
- // Moving onto the two bottom outputs.
- acc_2 = MultiplyAccumulateRow(acc_2, filter.f0, filter.f1, filter.f2,
- input_0, input_1, input_2);
-
- acc_3 = MultiplyAccumulateRow(acc_3, filter.f0, filter.f1, filter.f2,
- input_2, input_3, input_4);
+// clang-format gets confused with this file and ends up formatting lines to
+// be larger than 80 characters. Turn off here and back on at the end of the
+// file.
- acc_2 = MultiplyAccumulateRow(acc_2, filter.f3, filter.f4, filter.f5,
- input_5, input_6, input_7);
+// clang-format off
- acc_3 = MultiplyAccumulateRow(acc_3, filter.f3, filter.f4, filter.f5,
- input_7, input_8, input_9);
-
- // Load last input row.
- {
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- temp_3 = vld1_u8(ptr + 3 * input_depth);
- temp_4 = vld1_u8(ptr + 4 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- }
-
- acc_2 = MultiplyAccumulateRow(acc_2, filter.f6, filter.f7, filter.f8,
- input_0, input_1, input_2);
-
- acc_3 = MultiplyAccumulateRow(acc_3, filter.f6, filter.f7, filter.f8,
- input_2, input_3, input_4);
- }
+#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64
- DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset,
- output_multiplier, output_shift,
- output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
- }
+// Encapsulates constant parameters used in DepthwiseConv.
+// 64-bit is used for types that will be added to 64-bit addresses in asm.
+struct DepthwiseConvParams {
+ int64_t input_depth;
+ int64_t input_row_size;
+ int64_t output_depth;
+ int64_t output_row_size;
+ int32 input_offset;
+ int32 output_offset;
+ int32 filter_offset;
+ int32 output_multiplier;
+ int32 output_activation_min;
+ int32 output_activation_max;
+ int32 output_shift;
+ int32 input_width;
+ int32 input_height;
+ int32 output_width;
+ int32 output_height;
};
-template <>
-struct ConvKernel3x3FilterDepth8<2, 4, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- // Reuse 2x2 kernel twice.
- ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr, output_depth,
- output_width);
-
- ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run(
- input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr + 2 * output_depth, output_depth, output_width);
- }
-};
+#define STR(s) STR_UNEXPANDED(s)
+#define STR_UNEXPANDED(s) #s
+
+// Represents the number of bytes offset from the start of the
+// DepthwiseConvParams struct. This is used in the asm to load parameters.
+// Keep these values in sync with the static_asserts below.
+#define OFFSET_INPUT_DEPTH 0
+#define OFFSET_INPUT_ROW_SIZE 8
+#define OFFSET_OUTPUT_DEPTH 16
+#define OFFSET_OUTPUT_ROW_SIZE 24
+#define OFFSET_INPUT_OFFSET 32
+#define OFFSET_OUTPUT_OFFSET 36
+#define OFFSET_FILTER_OFFSET 40
+#define OFFSET_OUTPUT_MULTIPLIER 44
+#define OFFSET_OUTPUT_ACTIVATION_MIN 48
+#define OFFSET_OUTPUT_ACTIVATION_MAX 52
+#define OFFSET_OUTPUT_SHIFT 56
+#define OFFSET_INPUT_WIDTH 60
+#define OFFSET_INPUT_HEIGHT 64
+#define OFFSET_OUTPUT_WIDTH 68
+#define OFFSET_OUTPUT_HEIGHT 72
+
+static_assert(offsetof(DepthwiseConvParams, input_depth) ==
+ OFFSET_INPUT_DEPTH, "");
+static_assert(offsetof(DepthwiseConvParams, input_row_size) ==
+ OFFSET_INPUT_ROW_SIZE, "");
+static_assert(offsetof(DepthwiseConvParams, output_depth) ==
+ OFFSET_OUTPUT_DEPTH, "");
+static_assert(offsetof(DepthwiseConvParams, output_row_size) ==
+ OFFSET_OUTPUT_ROW_SIZE, "");
+static_assert(offsetof(DepthwiseConvParams, input_offset) ==
+ OFFSET_INPUT_OFFSET, "");
+static_assert(offsetof(DepthwiseConvParams, output_offset) ==
+ OFFSET_OUTPUT_OFFSET, "");
+static_assert(offsetof(DepthwiseConvParams, filter_offset) ==
+ OFFSET_FILTER_OFFSET, "");
+static_assert(offsetof(DepthwiseConvParams, output_multiplier) ==
+ OFFSET_OUTPUT_MULTIPLIER, "");
+static_assert(offsetof(DepthwiseConvParams, output_activation_min) ==
+ OFFSET_OUTPUT_ACTIVATION_MIN, "");
+static_assert(offsetof(DepthwiseConvParams, output_activation_max) ==
+ OFFSET_OUTPUT_ACTIVATION_MAX, "");
+static_assert(offsetof(DepthwiseConvParams, output_shift) ==
+ OFFSET_OUTPUT_SHIFT, "");
+static_assert(offsetof(DepthwiseConvParams, input_width) ==
+ OFFSET_INPUT_WIDTH, "");
+static_assert(offsetof(DepthwiseConvParams, input_height) ==
+ OFFSET_INPUT_HEIGHT, "");
+static_assert(offsetof(DepthwiseConvParams, output_width) ==
+ OFFSET_OUTPUT_WIDTH, "");
+static_assert(offsetof(DepthwiseConvParams, output_height) ==
+ OFFSET_OUTPUT_HEIGHT, "");
+
+template <int32 kDepth, int32 kStrideWidth, int32 kStrideHeight>
+struct DepthwiseConvWindow {};
template <>
-struct ConvKernel3x3FilterDepth8<2, 1, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- const int output_row_size = output_depth * output_width;
-
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7,
- temp_8;
-
- const uint8* ptr = input_ptr;
-
- // Load all inputs for top output.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Second output.
- output_ptr += output_row_size;
-
- ptr += input_row_size;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
-
- DotProductAndStore(
- filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3,
- input_4, input_5, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
+struct DepthwiseConvWindow<8, 1, 1> {
+ public:
+ static void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr, int64_t input_depth,
+ int64_t input_row_size, int32 output_window_height,
+ int32 output_window_width,
+ const DepthwiseConvParams* params_ptr) {
+ const int64_t input_width_increment = 2 * input_depth;
+ const int64_t input_height_increment = 2 * input_row_size;
+ const int64_t output_height_increment = 2 * params_ptr->output_row_size;
+
+#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6"
+#define DEPTHWISECONV_LABEL_HEIGHT_1 "7"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11"
+
+ asm volatile(
+ // Performs depthwise convolutions for a window specified by
+ // |output_window_height| and |output_window_width|. The inner-most loop
+ // processes 2x2 outputs, and any leftovers at the end.
+ //
+ // Algorithm works as follows:
+ //
+ // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter
+ // values.
+ // 2. For 2 output heights at a time:
+ // i. For 2 output widths at a time, load inputs for a 2x1 (2
+ // height, 1 width) output window (4x3 input window).
+ // Registers v9--v20 hold input values. Mul-add with
+ // accumulators v21--v24. Then run activation, downquantize
+ // and store. Repeat for the next 2x1 output window,
+ // leveraging overlapping inputs.
+ // ii. Handle single leftover width if exists.
+ // 3. Handle single leftover height if exists.
+ // i. For 2 output widths at a time, load inputs for a 1x2 (1
+ // height, 2 width) output window (3x4 input window).
+ // Registers v9--v20 hold input values. Mul-add with
+ // accumulators v21--v24. Then run activation, downquantize
+ // and store. Repeat for the next 1x2 output window,
+ // leveraging overlapping inputs.
+ // ii. Handle single leftover width if exists.
+ //
+ // Loads are placed as soon as the register is no longer needed and
+ // interleaved with arithmetic operations to take advantage of
+ // dual-issue pipelines. We also add input offsets as far from the loads
+ // as possible to give loads enough cycles to fetch data from memory.
+
+ // Set "constant" registers. These registers may be replaced with temp
+ // values from time to time when there are not enough NEON registers.
+ // We use x9--x15 general purpose registers as they are caller-saved
+ // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr x3, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "cmp %w[output_window_height], #2\n"
+ "dup v26.8h, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "dup v29.4s, w2\n"
+ "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "dup v30.4s, w4\n"
+ "ldr w0, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v31.4s, w0\n"
+ "neg w9, w9\n"
+ "dup v28.4s, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "add x10, %[bias_ptr], #16\n"
+ "ldr x1, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n"
+ "dup v9.8h, w9\n"
+
+ // Load filters and add offsets.
+ "ld1 {v0.8b}, [%[filter_ptr]], x3\n"
+ "ld1 {v1.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v0.8h, v9.8h, v0.8b\n"
+ "ld1 {v2.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v1.8h, v9.8h, v1.8b\n"
+ "ld1 {v3.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v2.8h, v9.8h, v2.8b\n"
+ "ld1 {v4.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v3.8h, v9.8h, v3.8b\n"
+ "ld1 {v5.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v4.8h, v9.8h, v4.8b\n"
+ "ld1 {v6.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v5.8h, v9.8h, v5.8b\n"
+ "ld1 {v7.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v6.8h, v9.8h, v6.8b\n"
+ "ld1 {v8.8b}, [%[filter_ptr]], x3\n"
+ "uaddw v7.8h, v9.8h, v7.8b\n"
+ "uaddw v8.8h, v9.8h, v8.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n"
+ // This loop processes 2x2 outputs. To avoid register exhaustion,
+ // inputs for the left 2 outputs are loaded first, then the right
+ // two outputs.
+ "mov x11, %[input_ptr]\n"
+ "mov x12, x11\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "add x13, x11, %[input_row_size]\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "add x14, x13, %[input_row_size]\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "add x15, x14, %[input_row_size]\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "mov w5, %w[output_window_width]\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "mov x6, %[output_ptr]\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "add x7, %[output_ptr], x1\n"
+ "ld1 {v15.8b}, [x14], %[input_depth]\n"
+ // The height 2 / width 2 loop loads an extra 2x1 outputs (2 height,
+ // 1 width) in anticipation for the next iteration. Make sure
+ // |output_window_width| is large enough to handle the additional
+ // loads, otherwise jump to specific the appropriate label to handle
+ // smaller widths.
+ "cmp w5, #2\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v16.8b}, [x14], %[input_depth]\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "ld1 {v17.8b}, [x14], %[input_depth]\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "ld1 {v18.8b}, [x15], %[input_depth]\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "ld1 {v19.8b}, [x15], %[input_depth]\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+ "ld1 {v20.8b}, [x15], %[input_depth]\n"
+ "uaddw v14.8h, v26.8h, v14.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "uaddw v17.8h, v26.8h, v17.8b\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "uaddw v19.8h, v26.8h, v19.8b\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n"
+ "cmp w5, #1\n"
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n"
+ // Mul-add left outputs.
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "subs w5, w5, #2\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "cmp w5, #3\n"
+ "smlal v23.4s, v0.4h, v12.4h\n"
+ "ld1 {v9.8b}, [x12]\n"
+ "smlal2 v24.4s, v0.8h, v12.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v23.4s, v1.4h, v13.4h\n"
+ "smlal2 v24.4s, v1.8h, v13.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v23.4s, v2.4h, v14.4h\n"
+ "smlal2 v24.4s, v2.8h, v14.8h\n"
+ "smlal v21.4s, v3.4h, v12.4h\n"
+ "smlal2 v22.4s, v3.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x13]\n"
+ "smlal v23.4s, v3.4h, v15.4h\n"
+ "smlal2 v24.4s, v3.8h, v15.8h\n"
+ "smlal v21.4s, v4.4h, v13.4h\n"
+ "smlal2 v22.4s, v4.8h, v13.8h\n"
+ "smlal v23.4s, v4.4h, v16.4h\n"
+ "smlal2 v24.4s, v4.8h, v16.8h\n"
+ "smlal v21.4s, v5.4h, v14.4h\n"
+ "smlal2 v22.4s, v5.8h, v14.8h\n"
+ "smlal v23.4s, v5.4h, v17.4h\n"
+ "smlal2 v24.4s, v5.8h, v17.8h\n"
+ "smlal v21.4s, v6.4h, v15.4h\n"
+ "smlal2 v22.4s, v6.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x14]\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x15]\n"
+ "smlal v21.4s, v7.4h, v16.4h\n"
+ "smlal2 v22.4s, v7.8h, v16.8h\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v17.4h\n"
+ "smlal2 v22.4s, v8.8h, v17.8h\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "st1 {v23.8b}, [x7], x3\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+
+ // Mul-add right outputs.
+ "smlal v21.4s, v0.4h, v10.4h\n"
+ "add x11, x11, %[input_width_increment]\n"
+ "smlal2 v22.4s, v0.8h, v10.8h\n"
+ "mov x12, x11\n"
+ "smlal v23.4s, v0.4h, v13.4h\n"
+ "add x13, x11, %[input_row_size]\n"
+ "smlal2 v24.4s, v0.8h, v13.8h\n"
+ "add x14, x13, %[input_row_size]\n"
+ "smlal v21.4s, v1.4h, v11.4h\n"
+ "add x15, x14, %[input_row_size]\n"
+ "smlal2 v22.4s, v1.8h, v11.8h\n"
+ "smlal v23.4s, v1.4h, v14.4h\n"
+ "smlal2 v24.4s, v1.8h, v14.8h\n"
+ "smlal v21.4s, v2.4h, v9.4h\n"
+ "smlal2 v22.4s, v2.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v2.4h, v12.4h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "smlal2 v24.4s, v2.8h, v12.8h\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "smlal v23.4s, v3.4h, v16.4h\n"
+ "smlal2 v24.4s, v3.8h, v16.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "smlal v23.4s, v4.4h, v17.4h\n"
+ "smlal2 v24.4s, v4.8h, v17.8h\n"
+ "smlal v21.4s, v5.4h, v12.4h\n"
+ "smlal2 v22.4s, v5.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v5.4h, v15.4h\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v5.8h, v15.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v21.4s, v6.4h, v16.4h\n"
+ "smlal2 v22.4s, v6.8h, v16.8h\n"
+ "smlal v23.4s, v6.4h, v19.4h\n"
+ "smlal2 v24.4s, v6.8h, v19.8h\n"
+ "smlal v21.4s, v7.4h, v17.4h\n"
+ "smlal2 v22.4s, v7.8h, v17.8h\n"
+ "smlal v23.4s, v7.4h, v20.4h\n"
+ "smlal2 v24.4s, v7.8h, v20.8h\n"
+ "smlal v21.4s, v8.4h, v15.4h\n"
+ "smlal2 v22.4s, v8.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x14], %[input_depth]\n"
+ "smlal v23.4s, v8.4h, v18.4h\n"
+ "ld1 {v16.8b}, [x14], %[input_depth]\n"
+ "smlal2 v24.4s, v8.8h, v18.8h\n"
+ "ld1 {v17.8b}, [x14], %[input_depth]\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "ld1 {v18.8b}, [x15], %[input_depth]\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "ld1 {v19.8b}, [x15], %[input_depth]\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "ld1 {v20.8b}, [x15], %[input_depth]\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "st1 {v23.8b}, [x7], x3\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+ "uaddw v14.8h, v26.8h, v14.8b\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "uaddw v17.8h, v26.8h, v17.8b\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "uaddw v19.8h, v26.8h, v19.8b\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n"
+
+ // At this point, there will be one of 2 width or 1 width leftover,
+ // not both.
+ "cmp w5, #2\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n"
+
+ // Handle last 2 columns if exists.
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n"
+ // Mul-add left outputs.
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "smlal v23.4s, v0.4h, v12.4h\n"
+ "ld1 {v9.8b}, [x12]\n"
+ "smlal2 v24.4s, v0.8h, v12.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v23.4s, v1.4h, v13.4h\n"
+ "smlal2 v24.4s, v1.8h, v13.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v23.4s, v2.4h, v14.4h\n"
+ "smlal2 v24.4s, v2.8h, v14.8h\n"
+ "smlal v21.4s, v3.4h, v12.4h\n"
+ "smlal2 v22.4s, v3.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x13]\n"
+ "smlal v23.4s, v3.4h, v15.4h\n"
+ "smlal2 v24.4s, v3.8h, v15.8h\n"
+ "smlal v21.4s, v4.4h, v13.4h\n"
+ "smlal2 v22.4s, v4.8h, v13.8h\n"
+ "smlal v23.4s, v4.4h, v16.4h\n"
+ "smlal2 v24.4s, v4.8h, v16.8h\n"
+ "smlal v21.4s, v5.4h, v14.4h\n"
+ "smlal2 v22.4s, v5.8h, v14.8h\n"
+ "smlal v23.4s, v5.4h, v17.4h\n"
+ "smlal2 v24.4s, v5.8h, v17.8h\n"
+ "smlal v21.4s, v6.4h, v15.4h\n"
+ "smlal2 v22.4s, v6.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x14]\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x15]\n"
+ "smlal v21.4s, v7.4h, v16.4h\n"
+ "smlal2 v22.4s, v7.8h, v16.8h\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v17.4h\n"
+ "smlal2 v22.4s, v8.8h, v17.8h\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "st1 {v23.8b}, [x7], x3\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+
+ // Mul-add right outputs.
+ "smlal v21.4s, v0.4h, v10.4h\n"
+ "smlal2 v22.4s, v0.8h, v10.8h\n"
+ "smlal v23.4s, v0.4h, v13.4h\n"
+ "smlal2 v24.4s, v0.8h, v13.8h\n"
+ "smlal v21.4s, v1.4h, v11.4h\n"
+ "smlal2 v22.4s, v1.8h, v11.8h\n"
+ "smlal v23.4s, v1.4h, v14.4h\n"
+ "smlal2 v24.4s, v1.8h, v14.8h\n"
+ "smlal v21.4s, v2.4h, v9.4h\n"
+ "smlal2 v22.4s, v2.8h, v9.8h\n"
+ "smlal v23.4s, v2.4h, v12.4h\n"
+ "smlal2 v24.4s, v2.8h, v12.8h\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "smlal v23.4s, v3.4h, v16.4h\n"
+ "smlal2 v24.4s, v3.8h, v16.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "smlal v23.4s, v4.4h, v17.4h\n"
+ "smlal2 v24.4s, v4.8h, v17.8h\n"
+ "smlal v21.4s, v5.4h, v12.4h\n"
+ "smlal2 v22.4s, v5.8h, v12.8h\n"
+ "smlal v23.4s, v5.4h, v15.4h\n"
+ "smlal2 v24.4s, v5.8h, v15.8h\n"
+ "smlal v21.4s, v6.4h, v16.4h\n"
+ "smlal2 v22.4s, v6.8h, v16.8h\n"
+ "smlal v23.4s, v6.4h, v19.4h\n"
+ "smlal2 v24.4s, v6.8h, v19.8h\n"
+ "smlal v21.4s, v7.4h, v17.4h\n"
+ "smlal2 v22.4s, v7.8h, v17.8h\n"
+ "smlal v23.4s, v7.4h, v20.4h\n"
+ "smlal2 v24.4s, v7.8h, v20.8h\n"
+ "smlal v21.4s, v8.4h, v15.4h\n"
+ "smlal2 v22.4s, v8.8h, v15.8h\n"
+ "smlal v23.4s, v8.4h, v18.4h\n"
+ "smlal2 v24.4s, v8.8h, v18.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "st1 {v23.8b}, [x7], x3\n"
+ "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "smlal v23.4s, v0.4h, v12.4h\n"
+ "smlal2 v24.4s, v0.8h, v12.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v23.4s, v1.4h, v13.4h\n"
+ "smlal2 v24.4s, v1.8h, v13.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v23.4s, v2.4h, v14.4h\n"
+ "smlal2 v24.4s, v2.8h, v14.8h\n"
+ "smlal v21.4s, v3.4h, v12.4h\n"
+ "smlal2 v22.4s, v3.8h, v12.8h\n"
+ "smlal v23.4s, v3.4h, v15.4h\n"
+ "smlal2 v24.4s, v3.8h, v15.8h\n"
+ "smlal v21.4s, v4.4h, v13.4h\n"
+ "smlal2 v22.4s, v4.8h, v13.8h\n"
+ "smlal v23.4s, v4.4h, v16.4h\n"
+ "smlal2 v24.4s, v4.8h, v16.8h\n"
+ "smlal v21.4s, v5.4h, v14.4h\n"
+ "smlal2 v22.4s, v5.8h, v14.8h\n"
+ "smlal v23.4s, v5.4h, v17.4h\n"
+ "smlal2 v24.4s, v5.8h, v17.8h\n"
+ "smlal v21.4s, v6.4h, v15.4h\n"
+ "smlal2 v22.4s, v6.8h, v15.8h\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "smlal v21.4s, v7.4h, v16.4h\n"
+ "smlal2 v22.4s, v7.8h, v16.8h\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v17.4h\n"
+ "smlal2 v22.4s, v8.8h, v17.8h\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v9.16b, v21.16b, v28.16b\n"
+ "and v12.16b, v22.16b, v28.16b\n"
+ "and v15.16b, v23.16b, v28.16b\n"
+ "and v18.16b, v24.16b, v28.16b\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v12.4s, v12.4s, #31\n"
+ "sshr v15.4s, v15.4s, #31\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v9.4s\n"
+ "sqadd v22.4s, v22.4s, v12.4s\n"
+ "sqadd v23.4s, v23.4s, v15.4s\n"
+ "sqadd v24.4s, v24.4s, v18.4s\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "st1 {v21.8b}, [x6], x3\n"
+ "st1 {v23.8b}, [x7], x3\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n"
+ "subs %w[output_window_height], %w[output_window_height], #2\n"
+ "add %[input_ptr], %[input_ptr], %[input_height_increment]\n"
+ "cmp %w[output_window_height], #2\n"
+ "add %[output_ptr], %[output_ptr], %[output_height_increment]\n"
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n"
+ "cmp %w[output_window_height], #1\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_1 ":\n"
+ "mov x12, %[input_ptr]\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "add x13, %[input_ptr], %[input_row_size]\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "add x14, x13, %[input_row_size]\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "add x15, x14, %[input_row_size]\n"
+ "mov w5, %w[output_window_width]\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "mov x6, %[output_ptr]\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "add x7, %[output_ptr], x1\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ // The height 1 / width 2 loop loads an extra 1x1 output in anticipation
+ // for the next iteration. Make sure |output_window_width| is large
+ // enough to handle the additional load, otherwise jump to the
+ // appropriate label to handle smaller widths.
+ "cmp w5, #2\n"
+ "ld1 {v17.8b}, [x14], %[input_depth]\n"
+ "ld1 {v18.8b}, [x14], %[input_depth]\n"
+ "ld1 {v19.8b}, [x14], %[input_depth]\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "ld1 {v24.4s}, [x10]\n"
+
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+ "uaddw v14.8h, v26.8h, v14.8b\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "uaddw v17.8h, v26.8h, v17.8b\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "uaddw v19.8h, v26.8h, v19.8b\n"
+
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n"
+ "cmp w5, #1\n"
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n"
+ // Load inputs for 3x4 input window which corresponds to a 1x2 output
+ // window.
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v12.8b}, [x12]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v16.8b}, [x13]\n"
+ "smlal v23.4s, v0.4h, v10.4h\n"
+ "ld1 {v20.8b}, [x14]\n"
+ "smlal2 v24.4s, v0.8h, v10.8h\n"
+ "subs w5, w5, #2\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "cmp w5, #3\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "add %[input_ptr], %[input_ptr], %[input_width_increment]\n"
+ "smlal v23.4s, v1.4h, v11.4h\n"
+ "mov x12, %[input_ptr]\n"
+ "smlal2 v24.4s, v1.8h, v11.8h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "add x13, %[input_ptr], %[input_row_size]\n"
+ "smlal v23.4s, v2.4h, v12.4h\n"
+ "add x14, x13, %[input_row_size]\n"
+ "smlal2 v24.4s, v2.8h, v12.8h\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "add x15, x14, %[input_row_size]\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v3.4h, v14.4h\n"
+ "smlal2 v24.4s, v3.8h, v14.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v4.4h, v15.4h\n"
+ "smlal2 v24.4s, v4.8h, v15.8h\n"
+ "smlal v21.4s, v5.4h, v15.4h\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "smlal2 v22.4s, v5.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v5.4h, v16.4h\n"
+ "smlal2 v24.4s, v5.8h, v16.8h\n"
+ "smlal v21.4s, v6.4h, v17.4h\n"
+ "smlal2 v22.4s, v6.8h, v17.8h\n"
+ "ld1 {v17.8b}, [x14], %[input_depth]\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "smlal v21.4s, v7.4h, v18.4h\n"
+ "smlal2 v22.4s, v7.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x14], %[input_depth]\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v19.4h\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+ "smlal2 v22.4s, v8.8h, v19.8h\n"
+ "ld1 {v19.8b}, [x14], %[input_depth]\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "st1 {v21.8b}, [%[output_ptr]], x3\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "st1 {v23.8b}, [%[output_ptr]], x3\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+ "uaddw v14.8h, v26.8h, v14.8b\n"
+ "uaddw v15.8h, v26.8h, v15.8b\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "uaddw v17.8h, v26.8h, v17.8b\n"
+ "uaddw v18.8h, v26.8h, v18.8b\n"
+ "uaddw v19.8h, v26.8h, v19.8b\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n"
+
+ // At this point, there will be one of 2 width or 1 width leftover,
+ // not both.
+ "cmp w5, #2\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n"
+
+ // Handle last two horizontal outputs if exists.
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v12.8b}, [x12], %[input_depth]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v16.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v0.4h, v10.4h\n"
+ "ld1 {v20.8b}, [x14], %[input_depth]\n"
+ "smlal2 v24.4s, v0.8h, v10.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v23.4s, v1.4h, v11.4h\n"
+ "smlal2 v24.4s, v1.8h, v11.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v23.4s, v2.4h, v12.4h\n"
+ "smlal2 v24.4s, v2.8h, v12.8h\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "smlal v23.4s, v3.4h, v14.4h\n"
+ "smlal2 v24.4s, v3.8h, v14.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "smlal v23.4s, v4.4h, v15.4h\n"
+ "smlal2 v24.4s, v4.8h, v15.8h\n"
+ "smlal v21.4s, v5.4h, v15.4h\n"
+ "uaddw v16.8h, v26.8h, v16.8b\n"
+ "smlal2 v22.4s, v5.8h, v15.8h\n"
+ "smlal v23.4s, v5.4h, v16.4h\n"
+ "smlal2 v24.4s, v5.8h, v16.8h\n"
+ "smlal v21.4s, v6.4h, v17.4h\n"
+ "smlal2 v22.4s, v6.8h, v17.8h\n"
+ "smlal v23.4s, v6.4h, v18.4h\n"
+ "smlal2 v24.4s, v6.8h, v18.8h\n"
+ "smlal v21.4s, v7.4h, v18.4h\n"
+ "smlal2 v22.4s, v7.8h, v18.8h\n"
+ "smlal v23.4s, v7.4h, v19.4h\n"
+ "smlal2 v24.4s, v7.8h, v19.8h\n"
+ "smlal v21.4s, v8.4h, v19.4h\n"
+ "uaddw v20.8h, v26.8h, v20.8b\n"
+ "smlal2 v22.4s, v8.8h, v19.8h\n"
+ "smlal v23.4s, v8.4h, v20.4h\n"
+ "smlal2 v24.4s, v8.8h, v20.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v25.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v25.4s, v25.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v25.4s\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w4\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w0\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "st1 {v21.8b}, [%[output_ptr]], x3\n"
+ "st1 {v23.8b}, [%[output_ptr]], x3\n"
+ "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n"
+
+ // Handle bottom right output if exists.
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "smlal v21.4s, v3.4h, v13.4h\n"
+ "smlal2 v22.4s, v3.8h, v13.8h\n"
+ "smlal v21.4s, v4.4h, v14.4h\n"
+ "smlal2 v22.4s, v4.8h, v14.8h\n"
+ "smlal v21.4s, v5.4h, v15.4h\n"
+ "smlal2 v22.4s, v5.8h, v15.8h\n"
+ "smlal v21.4s, v6.4h, v17.4h\n"
+ "smlal2 v22.4s, v6.8h, v17.8h\n"
+ "smlal v21.4s, v7.4h, v18.4h\n"
+ "smlal2 v22.4s, v7.8h, v18.8h\n"
+ "smlal v21.4s, v8.4h, v19.4h\n"
+ "smlal2 v22.4s, v8.8h, v19.8h\n"
+
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "and v9.16b, v21.16b, v28.16b\n"
+ "and v12.16b, v22.16b, v28.16b\n"
+ "sshr v9.4s, v9.4s, #31\n"
+ "sshr v12.4s, v12.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v9.4s\n"
+ "sqadd v22.4s, v22.4s, v12.4s\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "st1 {v21.8b}, [%[output_ptr]]\n"
+ DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr),
+ [output_window_height] "+r"(output_window_height)
+ :
+ // Inputs.
+ [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size),
+ [input_depth] "r"(input_depth),
+ [output_window_width] "r"(output_window_width),
+ [input_width_increment] "r"(input_width_increment),
+ [input_height_increment] "r"(input_height_increment),
+ [output_height_increment] "r"(output_height_increment),
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31",
+ // We use these general-purpose registers.
+ "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
+ "x9", "x10", "x11", "x12", "x13", "x14", "x15");
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_1
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_END
}
};
template <>
-struct ConvKernel3x3FilterDepth8<1, 2, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7,
- temp_8;
-
- const uint8* ptr = input_ptr;
-
- // Load all inputs for top output.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Second output.
- output_ptr += output_depth;
-
- ptr = input_ptr + 3 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- DotProductAndStore(
- filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8,
- input_6, input_7, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
+struct DepthwiseConvWindow<8, 2, 2> {
+ static void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr, int64_t input_depth,
+ int64_t input_row_size, int32 output_window_height,
+ int32 output_window_width,
+ const DepthwiseConvParams* params_ptr) {
+ const int64_t input_width_increment = 4 * input_depth;
+ const int64_t input_height_increment = 4 * input_row_size;
+ const int64_t output_height_increment = 2 * params_ptr->output_row_size;
+
+#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5"
+#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6"
+#define DEPTHWISECONV_LABEL_HEIGHT_1 "7"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10"
+#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11"
+
+ asm volatile(
+ // Performs depthwise convolutions for a window specified by
+ // |output_window_height| and |output_window_width|. The inner-most loop
+ // processes 2x2 outputs, and any leftovers at the end.
+ //
+ // Algorithm works as follows:
+ //
+ // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter
+ // values.
+ // 2. For 2 output heights at a time:
+ // i. For 2 output widths at a time at stride 2, a 5x5 input
+ // window is required. To avoid register exhaustion, we load
+ // the first 2 rows of the 5x5 input window into registers
+ // v9--v18, and use the same registers to load the next 2
+ // rows, and finally v9--v13 to load the last row.
+ // Accumulators for all 2x2 outputs are reserved by registers
+ // v21-v22 (top left output), v23-v24 (top right output),
+ // v19-v20 (bottom left output), v25-v26 (bottom right
+ // output).
+ // ii. Handle single leftover width if exists.
+ // 3. Handle single leftover height if exists.
+ // i. For 2 output widths at a time at stride 2, load inputs for
+ // a 1x2 (1 height, 2 width) output window (3x5 input
+ // window). Registers v9--v24 hold input values. Mul-add with
+ // accumulators v24--v27.
+ // ii. Handle single leftover width if exists.
+ //
+ // Loads are placed as soon as the register is no longer needed and
+ // interleaved with arithmetic operations to take advantage of
+ // dual-issue pipelines. We also add input offsets as far from the loads
+ // as possible to give loads enough cycles to fetch data from memory.
+
+ // Set "constant" registers. These registers may be replaced with temp
+ // values from time to time when there are not enough NEON registers.
+ // We use x9--x15 general purpose registers as they are caller-saved
+ // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "cmp %w[output_window_height], #2\n"
+ "dup v28.8h, w0\n"
+ "neg w9, w9\n"
+ "ldr w1, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.4s, w9\n"
+ "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w1\n"
+ "ldr w3, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "dup v29.4s, w2\n"
+ "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w3\n"
+ "ldr x5, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "dup v31.4s, w4\n"
+ "ldr x19, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n"
+ "ldr w20, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+
+ // Load filters and add offsets.
+ "add x10, %[bias_ptr], #16\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], x5\n"
+ "dup v9.8h, w20\n"
+ "ld1 {v1.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v0.8h, v9.8h, v0.8b\n"
+ "ld1 {v2.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v1.8h, v9.8h, v1.8b\n"
+ "ld1 {v3.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v2.8h, v9.8h, v2.8b\n"
+ "ld1 {v4.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v3.8h, v9.8h, v3.8b\n"
+ "ld1 {v5.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v4.8h, v9.8h, v4.8b\n"
+ "ld1 {v6.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v5.8h, v9.8h, v5.8b\n"
+ "ld1 {v7.8b}, [%[filter_ptr]], x5\n"
+ "uaddw v6.8h, v9.8h, v6.8b\n"
+ "ld1 {v8.8b}, [%[filter_ptr]]\n"
+ "uaddw v7.8h, v9.8h, v7.8b\n"
+ "uaddw v8.8h, v9.8h, v8.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n"
+ // Load the first two rows of the 5x5 input window, then reuse the
+ // same registers to load subsequent rows as they become available.
+ "mov x11, %[input_ptr]\n"
+ "mov x12, x11\n"
+ "add x13, x12, %[input_row_size]\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "mov w14, %w[output_window_width]\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ // The height 2 / width 2 loop loads an extra 1 output horizontally in
+ // anticipation for the next iteration. Make sure
+ // |output_window_width| is large enough to handle the additional
+ // load, otherwise jump to the appropriate label to handle smaller
+ // widths.
+ "cmp w14, #2\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "add x15, x13, %[input_row_size]\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "mov x6, %[output_ptr]\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ "add x7, %[output_ptr], x19\n"
+ "ld1 {v16.8b}, [x13], %[input_depth]\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "ld1 {v19.4s}, [%[bias_ptr]]\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "ld1 {v20.4s}, [x10]\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "ld1 {v25.4s}, [%[bias_ptr]]\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "ld1 {v26.4s}, [x10]\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n"
+ "cmp w14, #1\n"
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v12.8b}, [x12], %[input_depth]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v13.8b}, [x12]\n"
+ "add x12, x15, %[input_row_size]\n"
+ "smlal v23.4s, v0.4h, v11.4h\n"
+ "ld1 {v17.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v0.8h, v11.8h\n"
+ "ld1 {v18.8b}, [x13]\n"
+ "add x13, x12, %[input_row_size]\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "ld1 {v9.8b}, [x15], %[input_depth]\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v3.4h, v14.4h\n"
+ "smlal2 v22.4s, v3.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v3.4h, v16.4h\n"
+ "subs w14, w14, #2\n"
+ "smlal2 v24.4s, v3.8h, v16.8h\n"
+ "cmp w14, #3\n"
+ "smlal v21.4s, v4.4h, v15.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v22.4s, v4.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v5.4h, v16.4h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal2 v22.4s, v5.8h, v16.8h\n"
+ "ld1 {v16.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v1.4h, v12.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v24.4s, v1.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x15], %[input_depth]\n"
+ "smlal v23.4s, v2.4h, v13.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v24.4s, v2.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x15]\n"
+ "smlal v23.4s, v4.4h, v17.4h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "smlal2 v24.4s, v4.8h, v17.8h\n"
+ "ld1 {v17.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v5.4h, v18.4h\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "smlal2 v24.4s, v5.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x12]\n"
+
+ "smlal v21.4s, v6.4h, v9.4h\n"
+ "smlal2 v22.4s, v6.8h, v9.8h\n"
+ "smlal v19.4s, v0.4h, v9.4h\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "smlal2 v20.4s, v0.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v6.4h, v11.4h\n"
+ "smlal2 v24.4s, v6.8h, v11.8h\n"
+ "smlal v21.4s, v7.4h, v10.4h\n"
+ "smlal2 v22.4s, v7.8h, v10.8h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal v19.4s, v1.4h, v10.4h\n"
+ "smlal2 v20.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v7.4h, v12.4h\n"
+ "smlal2 v24.4s, v7.8h, v12.8h\n"
+ "smlal v25.4s, v1.4h, v12.4h\n"
+ "smlal2 v26.4s, v1.8h, v12.8h\n"
+ "smlal v21.4s, v8.4h, v11.4h\n"
+ "smlal2 v22.4s, v8.8h, v11.8h\n"
+ "add x11, x11, %[input_width_increment]\n"
+ "smlal v19.4s, v2.4h, v11.4h\n"
+ "mov x12, x11\n"
+ "smlal2 v20.4s, v2.8h, v11.8h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal v25.4s, v0.4h, v11.4h\n"
+ "smlal2 v26.4s, v0.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v8.4h, v13.4h\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v8.8h, v13.8h\n"
+ "smlal v25.4s, v2.4h, v13.4h\n"
+ "smlal2 v26.4s, v2.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x13]\n"
+ "add x13, x12, %[input_row_size]\n"
+ "add x15, x13, %[input_row_size]\n"
+
+ "dup v28.4s, w9\n"
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v27.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v27.4s, v27.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v27.4s\n"
+ "dup v27.4s, w1\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x5\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "st1 {v23.8b}, [x6], x5\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+
+ "smlal v19.4s, v6.4h, v9.4h\n"
+ "smlal2 v20.4s, v6.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal v25.4s, v6.4h, v11.4h\n"
+ "smlal2 v26.4s, v6.8h, v11.8h\n"
+ "smlal v19.4s, v7.4h, v10.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v20.4s, v7.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "smlal v25.4s, v7.4h, v12.4h\n"
+ "smlal2 v26.4s, v7.8h, v12.8h\n"
+ "smlal v19.4s, v8.4h, v11.4h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal2 v20.4s, v8.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "smlal v25.4s, v8.4h, v13.4h\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "smlal2 v26.4s, v8.8h, v13.8h\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "smlal v19.4s, v3.4h, v14.4h\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "smlal2 v20.4s, v3.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v25.4s, v3.4h, v16.4h\n"
+ "ld1 {v21.4s}, [%[bias_ptr]]\n"
+ "smlal2 v26.4s, v3.8h, v16.8h\n"
+ "ld1 {v23.4s}, [%[bias_ptr]]\n"
+ "smlal v19.4s, v4.4h, v15.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v20.4s, v4.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ "smlal v25.4s, v4.4h, v17.4h\n"
+ "smlal2 v26.4s, v4.8h, v17.8h\n"
+ "smlal v19.4s, v5.4h, v16.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v20.4s, v5.8h, v16.8h\n"
+ "ld1 {v16.8b}, [x13], %[input_depth]\n"
+ "smlal v25.4s, v5.4h, v18.4h\n"
+ "smlal2 v26.4s, v5.8h, v18.8h\n"
+
+ "dup v28.4s, w9\n"
+ "sqrdmulh v19.4s, v19.4s, v27.4s\n"
+ "sqrdmulh v20.4s, v20.4s, v27.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v27.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v27.4s\n"
+ "and v27.16b, v19.16b, v28.16b\n"
+ "and v29.16b, v20.16b, v28.16b\n"
+ "and v30.16b, v25.16b, v28.16b\n"
+ "and v31.16b, v26.16b, v28.16b\n"
+ "sshr v27.4s, v27.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v19.4s, v19.4s, v27.4s\n"
+ "dup v27.4s, w1\n"
+ "sqadd v20.4s, v20.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v25.4s, v25.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v26.4s, v26.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v19.4s, v19.4s, v28.4s\n"
+ "srshl v20.4s, v20.4s, v28.4s\n"
+ "srshl v25.4s, v25.4s, v28.4s\n"
+ "srshl v26.4s, v26.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "add v19.4s, v19.4s, v29.4s\n"
+ "add v20.4s, v20.4s, v29.4s\n"
+ "add v25.4s, v25.4s, v29.4s\n"
+ "add v26.4s, v26.4s, v29.4s\n"
+ "smax v19.4s, v19.4s, v30.4s\n"
+ "smax v20.4s, v20.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smax v26.4s, v26.4s, v30.4s\n"
+ "smin v19.4s, v19.4s, v31.4s\n"
+ "smin v20.4s, v20.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "smin v26.4s, v26.4s, v31.4s\n"
+ "sqxtn v19.4h, v19.4s\n"
+ "sqxtn v25.4h, v25.4s\n"
+ "sqxtn2 v19.8h, v20.4s\n"
+ "ld1 {v20.4s}, [x10]\n"
+ "sqxtn2 v25.8h, v26.4s\n"
+ "ld1 {v26.4s}, [x10]\n"
+ "sqxtun v19.8b, v19.8h\n"
+ "sqxtun v25.8b, v25.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v19.8b}, [x7], x5\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "st1 {v25.8b}, [x7], x5\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "ld1 {v19.4s}, [%[bias_ptr]]\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "ld1 {v25.4s}, [%[bias_ptr]]\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n"
+
+ // At this point, there will be one of 2 width or 1 width leftover,
+ // not both.
+ "cmp w14, #2\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n"
+
+ // Handle last 2 columns if exists.
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v12.8b}, [x12], %[input_depth]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v13.8b}, [x12]\n"
+ "add x12, x15, %[input_row_size]\n"
+ "smlal v23.4s, v0.4h, v11.4h\n"
+ "ld1 {v17.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v0.8h, v11.8h\n"
+ "ld1 {v18.8b}, [x13]\n"
+ "add x13, x12, %[input_row_size]\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "ld1 {v9.8b}, [x15], %[input_depth]\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v3.4h, v14.4h\n"
+ "smlal2 v22.4s, v3.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v3.4h, v16.4h\n"
+ "smlal2 v24.4s, v3.8h, v16.8h\n"
+ "smlal v21.4s, v4.4h, v15.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v22.4s, v4.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v5.4h, v16.4h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal2 v22.4s, v5.8h, v16.8h\n"
+ "ld1 {v16.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v1.4h, v12.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v24.4s, v1.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x15], %[input_depth]\n"
+ "smlal v23.4s, v2.4h, v13.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v24.4s, v2.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x15]\n"
+ "smlal v23.4s, v4.4h, v17.4h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "smlal2 v24.4s, v4.8h, v17.8h\n"
+ "ld1 {v17.8b}, [x12], %[input_depth]\n"
+ "smlal v23.4s, v5.4h, v18.4h\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "smlal2 v24.4s, v5.8h, v18.8h\n"
+ "ld1 {v18.8b}, [x12]\n"
+
+ "smlal v21.4s, v6.4h, v9.4h\n"
+ "smlal2 v22.4s, v6.8h, v9.8h\n"
+ "smlal v19.4s, v0.4h, v9.4h\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "smlal2 v20.4s, v0.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v6.4h, v11.4h\n"
+ "smlal2 v24.4s, v6.8h, v11.8h\n"
+ "smlal v21.4s, v7.4h, v10.4h\n"
+ "smlal2 v22.4s, v7.8h, v10.8h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal v19.4s, v1.4h, v10.4h\n"
+ "smlal2 v20.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v7.4h, v12.4h\n"
+ "smlal2 v24.4s, v7.8h, v12.8h\n"
+ "smlal v25.4s, v1.4h, v12.4h\n"
+ "smlal2 v26.4s, v1.8h, v12.8h\n"
+ "smlal v21.4s, v8.4h, v11.4h\n"
+ "smlal2 v22.4s, v8.8h, v11.8h\n"
+ "smlal v19.4s, v2.4h, v11.4h\n"
+ "smlal2 v20.4s, v2.8h, v11.8h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal v25.4s, v0.4h, v11.4h\n"
+ "smlal2 v26.4s, v0.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x13], %[input_depth]\n"
+ "smlal v23.4s, v8.4h, v13.4h\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "smlal2 v24.4s, v8.8h, v13.8h\n"
+ "smlal v25.4s, v2.4h, v13.4h\n"
+ "smlal2 v26.4s, v2.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x13]\n"
+
+ "dup v28.4s, w9\n"
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v27.16b, v21.16b, v28.16b\n"
+ "and v29.16b, v22.16b, v28.16b\n"
+ "and v30.16b, v23.16b, v28.16b\n"
+ "and v31.16b, v24.16b, v28.16b\n"
+ "sshr v27.4s, v27.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v27.4s\n"
+ "dup v27.4s, w1\n"
+ "sqadd v22.4s, v22.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v23.4s, v23.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v24.4s, v24.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v21.4s, v21.4s, v28.4s\n"
+ "srshl v22.4s, v22.4s, v28.4s\n"
+ "srshl v23.4s, v23.4s, v28.4s\n"
+ "srshl v24.4s, v24.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "ld1 {v22.4s}, [x10]\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "ld1 {v24.4s}, [x10]\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6], x5\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "st1 {v23.8b}, [x6]\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+
+ "smlal v19.4s, v6.4h, v9.4h\n"
+ "smlal2 v20.4s, v6.8h, v9.8h\n"
+ "smlal v25.4s, v6.4h, v11.4h\n"
+ "smlal2 v26.4s, v6.8h, v11.8h\n"
+ "smlal v19.4s, v7.4h, v10.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v20.4s, v7.8h, v10.8h\n"
+ "smlal v25.4s, v7.4h, v12.4h\n"
+ "smlal2 v26.4s, v7.8h, v12.8h\n"
+ "smlal v19.4s, v8.4h, v11.4h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "smlal2 v20.4s, v8.8h, v11.8h\n"
+ "smlal v25.4s, v8.4h, v13.4h\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "smlal2 v26.4s, v8.8h, v13.8h\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "smlal v19.4s, v3.4h, v14.4h\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "smlal2 v20.4s, v3.8h, v14.8h\n"
+ "smlal v25.4s, v3.4h, v16.4h\n"
+ "smlal2 v26.4s, v3.8h, v16.8h\n"
+ "smlal v19.4s, v4.4h, v15.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v20.4s, v4.8h, v15.8h\n"
+ "smlal v25.4s, v4.4h, v17.4h\n"
+ "smlal2 v26.4s, v4.8h, v17.8h\n"
+ "smlal v19.4s, v5.4h, v16.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v20.4s, v5.8h, v16.8h\n"
+ "smlal v25.4s, v5.4h, v18.4h\n"
+ "smlal2 v26.4s, v5.8h, v18.8h\n"
+
+ "dup v28.4s, w9\n"
+ "sqrdmulh v19.4s, v19.4s, v27.4s\n"
+ "sqrdmulh v20.4s, v20.4s, v27.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v27.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v27.4s\n"
+ "and v27.16b, v19.16b, v28.16b\n"
+ "and v29.16b, v20.16b, v28.16b\n"
+ "and v30.16b, v25.16b, v28.16b\n"
+ "and v31.16b, v26.16b, v28.16b\n"
+ "sshr v27.4s, v27.4s, #31\n"
+ "sshr v29.4s, v29.4s, #31\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v19.4s, v19.4s, v27.4s\n"
+ "dup v27.4s, w1\n"
+ "sqadd v20.4s, v20.4s, v29.4s\n"
+ "dup v29.4s, w2\n"
+ "sqadd v25.4s, v25.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v26.4s, v26.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v19.4s, v19.4s, v28.4s\n"
+ "srshl v20.4s, v20.4s, v28.4s\n"
+ "srshl v25.4s, v25.4s, v28.4s\n"
+ "srshl v26.4s, v26.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "add v19.4s, v19.4s, v29.4s\n"
+ "add v20.4s, v20.4s, v29.4s\n"
+ "add v25.4s, v25.4s, v29.4s\n"
+ "add v26.4s, v26.4s, v29.4s\n"
+ "smax v19.4s, v19.4s, v30.4s\n"
+ "smax v20.4s, v20.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smax v26.4s, v26.4s, v30.4s\n"
+ "smin v19.4s, v19.4s, v31.4s\n"
+ "smin v20.4s, v20.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "smin v26.4s, v26.4s, v31.4s\n"
+ "sqxtn v19.4h, v19.4s\n"
+ "sqxtn v25.4h, v25.4s\n"
+ "sqxtn2 v19.8h, v20.4s\n"
+ "sqxtn2 v25.8h, v26.4s\n"
+ "sqxtun v19.8b, v19.8h\n"
+ "sqxtun v25.8b, v25.8h\n"
+ "st1 {v19.8b}, [x7], x5\n"
+ "st1 {v25.8b}, [x7]\n"
+ "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n"
+
+ // Handle last column if exists.
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n"
+ // Registers v9, v10, v11, v14, v15, and v16 have already been loaded
+ // with the correct values at this point. This corresponds to the
+ // first two input rows of the top left output. Now load the last
+ // input row for this output. Once these inputs are no longer needed,
+ // load the input rows for the bottom left output.
+ "add x12, x15, %[input_row_size]\n"
+ "add x13, x12, %[input_row_size]\n"
+
+ "ld1 {v12.8b}, [x15], %[input_depth]\n"
+ "smlal v21.4s, v0.4h, v9.4h\n"
+ "ld1 {v13.8b}, [x15], %[input_depth]\n"
+ "smlal2 v22.4s, v0.8h, v9.8h\n"
+ "ld1 {v17.8b}, [x15]\n"
+ "smlal v21.4s, v1.4h, v10.4h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal2 v22.4s, v1.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "smlal v21.4s, v2.4h, v11.4h\n"
+ "smlal2 v22.4s, v2.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x12]\n"
+ "smlal v21.4s, v3.4h, v14.4h\n"
+ "smlal2 v22.4s, v3.8h, v14.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v21.4s, v4.4h, v15.4h\n"
+ "smlal2 v22.4s, v4.8h, v15.8h\n"
+ "ld1 {v15.8b}, [x13], %[input_depth]\n"
+ "smlal v21.4s, v5.4h, v16.4h\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "smlal2 v22.4s, v5.8h, v16.8h\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "ld1 {v16.8b}, [x13]\n"
+
+ "smlal v21.4s, v6.4h, v12.4h\n"
+ "smlal2 v22.4s, v6.8h, v12.8h\n"
+ "smlal v23.4s, v0.4h, v12.4h\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+ "smlal2 v24.4s, v0.8h, v12.8h\n"
+ "smlal v21.4s, v7.4h, v13.4h\n"
+ "smlal2 v22.4s, v7.8h, v13.8h\n"
+ "smlal v23.4s, v1.4h, v13.4h\n"
+ "smlal2 v24.4s, v1.8h, v13.8h\n"
+ "smlal v21.4s, v8.4h, v17.4h\n"
+ "smlal2 v22.4s, v8.8h, v17.8h\n"
+ "smlal v23.4s, v2.4h, v17.4h\n"
+ "smlal2 v24.4s, v2.8h, v17.8h\n"
+
+ "dup v26.4s, w9\n"
+ "sqrdmulh v21.4s, v21.4s, v27.4s\n"
+ "sqrdmulh v22.4s, v22.4s, v27.4s\n"
+ "and v18.16b, v21.16b, v26.16b\n"
+ "and v19.16b, v22.16b, v26.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v21.4s, v21.4s, v18.4s\n"
+ "sqadd v22.4s, v22.4s, v19.4s\n"
+ "srshl v21.4s, v21.4s, v26.4s\n"
+ "srshl v22.4s, v22.4s, v26.4s\n"
+ "add v21.4s, v21.4s, v29.4s\n"
+ "add v22.4s, v22.4s, v29.4s\n"
+ "smax v21.4s, v21.4s, v30.4s\n"
+ "smax v22.4s, v22.4s, v30.4s\n"
+ "smin v21.4s, v21.4s, v31.4s\n"
+ "smin v22.4s, v22.4s, v31.4s\n"
+ "sqxtn v21.4h, v21.4s\n"
+ "sqxtn2 v21.8h, v22.4s\n"
+ "sqxtun v21.8b, v21.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v21.8b}, [x6]\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+
+ "smlal v23.4s, v3.4h, v9.4h\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "smlal2 v24.4s, v3.8h, v9.8h\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "smlal v23.4s, v4.4h, v10.4h\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "smlal2 v24.4s, v4.8h, v10.8h\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "smlal v23.4s, v5.4h, v11.4h\n"
+ "smlal2 v24.4s, v5.8h, v11.8h\n"
+
+ "smlal v23.4s, v6.4h, v14.4h\n"
+ "smlal2 v24.4s, v6.8h, v14.8h\n"
+ "smlal v23.4s, v7.4h, v15.4h\n"
+ "smlal2 v24.4s, v7.8h, v15.8h\n"
+ "smlal v23.4s, v8.4h, v16.4h\n"
+ "smlal2 v24.4s, v8.8h, v16.8h\n"
+
+ "sqrdmulh v23.4s, v23.4s, v27.4s\n"
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "and v18.16b, v23.16b, v26.16b\n"
+ "and v19.16b, v24.16b, v26.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v23.4s, v23.4s, v18.4s\n"
+ "sqadd v24.4s, v24.4s, v19.4s\n"
+ "srshl v23.4s, v23.4s, v26.4s\n"
+ "srshl v24.4s, v24.4s, v26.4s\n"
+ "add v23.4s, v23.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "smax v23.4s, v23.4s, v30.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smin v23.4s, v23.4s, v31.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "sqxtn v23.4h, v23.4s\n"
+ "sqxtn2 v23.8h, v24.4s\n"
+ "sqxtun v23.8b, v23.8h\n"
+ "st1 {v23.8b}, [x7]\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n"
+ "subs %w[output_window_height], %w[output_window_height], #2\n"
+ "add %[input_ptr], %[input_ptr], %[input_height_increment]\n"
+ "cmp %w[output_window_height], #2\n"
+ "add %[output_ptr], %[output_ptr], %[output_height_increment]\n"
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n"
+ "cmp %w[output_window_height], #1\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_1 ":\n"
+ "mov x11, %[input_ptr]\n"
+ "mov x12, x11\n"
+ "add x13, x12, %[input_row_size]\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "add x15, x13, %[input_row_size]\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "mov x6, %[output_ptr]\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "mov w14, %w[output_window_width]\n"
+ // The height 1 / width 2 loop loads an extra 1x1 output in anticipation
+ // for the next iteration. Make sure |output_window_width| is large
+ // enough to handle the additional load, otherwise jump to the
+ // appropriate label to handle smaller widths.
+ "cmp w14, #2\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "ld1 {v15.8b}, [x15], %[input_depth]\n"
+ "ld1 {v16.8b}, [x15], %[input_depth]\n"
+ "ld1 {v17.8b}, [x15], %[input_depth]\n"
+
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "ld1 {v24.4s}, [%[bias_ptr]]\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "ld1 {v25.4s}, [x10]\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "ld1 {v26.4s}, [%[bias_ptr]]\n"
+ "ld1 {v27.4s}, [x10]\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n"
+ "cmp w14, #1\n"
+ "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n"
+ "smlal v24.4s, v0.4h, v9.4h\n"
+ "ld1 {v18.8b}, [x12], %[input_depth]\n"
+ "smlal2 v25.4s, v0.8h, v9.8h\n"
+ "ld1 {v19.8b}, [x12]\n"
+ "smlal v26.4s, v0.4h, v11.4h\n"
+ "ld1 {v20.8b}, [x13], %[input_depth]\n"
+ "smlal2 v27.4s, v0.8h, v11.8h\n"
+ "ld1 {v21.8b}, [x13]\n"
+ "smlal v24.4s, v1.4h, v10.4h\n"
+ "ld1 {v22.8b}, [x15], %[input_depth]\n"
+ "smlal2 v25.4s, v1.8h, v10.8h\n"
+ "ld1 {v23.8b}, [x15]\n"
+ "smlal v24.4s, v2.4h, v11.4h\n"
+ "subs w14, w14, #2\n"
+ "smlal2 v25.4s, v2.8h, v11.8h\n"
+ "cmp w14, #3\n"
+ "smlal v24.4s, v3.4h, v12.4h\n"
+ "add x11, x11, %[input_width_increment]\n"
+ "smlal2 v25.4s, v3.8h, v12.8h\n"
+ "mov x12, x11\n"
+ "smlal v26.4s, v3.4h, v14.4h\n"
+ "add x13, x12, %[input_row_size]\n"
+ "smlal2 v27.4s, v3.8h, v14.8h\n"
+ "add x15, x13, %[input_row_size]\n"
+ "smlal v24.4s, v4.4h, v13.4h\n"
+ "ld1 {v9.8b}, [x12], %[input_depth]\n"
+ "smlal2 v25.4s, v4.8h, v13.8h\n"
+ "ld1 {v10.8b}, [x12], %[input_depth]\n"
+ "smlal v24.4s, v5.4h, v14.4h\n"
+ "ld1 {v11.8b}, [x12], %[input_depth]\n"
+ "smlal2 v25.4s, v5.8h, v14.8h\n"
+ "ld1 {v12.8b}, [x13], %[input_depth]\n"
+ "smlal v24.4s, v6.4h, v15.4h\n"
+ "ld1 {v13.8b}, [x13], %[input_depth]\n"
+ "smlal2 v25.4s, v6.8h, v15.8h\n"
+ "ld1 {v14.8b}, [x13], %[input_depth]\n"
+ "smlal v26.4s, v6.4h, v17.4h\n"
+ "ld1 {v15.8b}, [x15], %[input_depth]\n"
+ "smlal2 v27.4s, v6.8h, v17.8h\n"
+ "smlal v24.4s, v7.4h, v16.4h\n"
+ "smlal2 v25.4s, v7.8h, v16.8h\n"
+ "ld1 {v16.8b}, [x15], %[input_depth]\n"
+ "smlal v24.4s, v8.4h, v17.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v25.4s, v8.8h, v17.8h\n"
+ "ld1 {v17.8b}, [x15], %[input_depth]\n"
+ "uaddw v19.8h, v28.8h, v19.8b\n"
+
+ "smlal v26.4s, v1.4h, v18.4h\n"
+ "uaddw v20.8h, v28.8h, v20.8b\n"
+ "smlal2 v27.4s, v1.8h, v18.8h\n"
+ "smlal v26.4s, v2.4h, v19.4h\n"
+ "uaddw v21.8h, v28.8h, v21.8b\n"
+ "smlal2 v27.4s, v2.8h, v19.8h\n"
+ "smlal v26.4s, v4.4h, v20.4h\n"
+ "smlal v26.4s, v5.4h, v21.4h\n"
+ "smlal2 v27.4s, v4.8h, v20.8h\n"
+ "uaddw v22.8h, v28.8h, v22.8b\n"
+ "smlal2 v27.4s, v5.8h, v21.8h\n"
+ "uaddw v23.8h, v28.8h, v23.8b\n"
+ "smlal v26.4s, v7.4h, v22.4h\n"
+ "smlal2 v27.4s, v7.8h, v22.8h\n"
+ "smlal v26.4s, v8.4h, v23.4h\n"
+ "smlal2 v27.4s, v8.8h, v23.8h\n"
+
+ "dup v28.4s, w1\n"
+ "dup v29.4s, w9\n"
+ "sqrdmulh v24.4s, v24.4s, v28.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v28.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v28.4s\n"
+ "sqrdmulh v27.4s, v27.4s, v28.4s\n"
+ "dup v28.4s, w2\n"
+ "and v30.16b, v24.16b, v29.16b\n"
+ "and v31.16b, v25.16b, v29.16b\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v24.4s, v24.4s, v30.4s\n"
+ "sqadd v25.4s, v25.4s, v31.4s\n"
+ "and v30.16b, v26.16b, v29.16b\n"
+ "and v31.16b, v27.16b, v29.16b\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v26.4s, v26.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v27.4s, v27.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v24.4s, v24.4s, v29.4s\n"
+ "srshl v25.4s, v25.4s, v29.4s\n"
+ "srshl v26.4s, v26.4s, v29.4s\n"
+ "srshl v27.4s, v27.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v28.4s\n"
+ "add v25.4s, v25.4s, v28.4s\n"
+ "add v26.4s, v26.4s, v28.4s\n"
+ "add v27.4s, v27.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smax v26.4s, v26.4s, v30.4s\n"
+ "smax v27.4s, v27.4s, v30.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "smin v26.4s, v26.4s, v31.4s\n"
+ "smin v27.4s, v27.4s, v31.4s\n"
+ "sqxtn v24.4h, v24.4s\n"
+ "sqxtn v26.4h, v26.4s\n"
+ "sqxtn2 v24.8h, v25.4s\n"
+ "ld1 {v25.4s}, [x10]\n"
+ "sqxtn2 v26.8h, v27.4s\n"
+ "ld1 {v27.4s}, [x10]\n"
+ "sqxtun v24.8b, v24.8h\n"
+ "sqxtun v26.8b, v26.8h\n"
+ "uaddw v9.8h, v28.8h, v9.8b\n"
+ "st1 {v24.8b}, [x6], x5\n"
+ "uaddw v10.8h, v28.8h, v10.8b\n"
+ "st1 {v26.8b}, [x6], x5\n"
+ "uaddw v11.8h, v28.8h, v11.8b\n"
+ "uaddw v12.8h, v28.8h, v12.8b\n"
+ "uaddw v13.8h, v28.8h, v13.8b\n"
+ "uaddw v14.8h, v28.8h, v14.8b\n"
+ "ld1 {v24.4s}, [%[bias_ptr]]\n"
+ "uaddw v15.8h, v28.8h, v15.8b\n"
+ "ld1 {v26.4s}, [%[bias_ptr]]\n"
+ "uaddw v16.8h, v28.8h, v16.8b\n"
+ "uaddw v17.8h, v28.8h, v17.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n"
+
+ // At this point, there will be one of 2 width or 1 width leftover,
+ // not both.
+ "cmp w14, #2\n"
+ "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n"
+
+ // Handle last two horizontal outputs if exists.
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n"
+ "smlal v24.4s, v0.4h, v9.4h\n"
+ "ld1 {v18.8b}, [x12], %[input_depth]\n"
+ "smlal2 v25.4s, v0.8h, v9.8h\n"
+ "ld1 {v19.8b}, [x12]\n"
+ "smlal v26.4s, v0.4h, v11.4h\n"
+ "ld1 {v20.8b}, [x13], %[input_depth]\n"
+ "smlal2 v27.4s, v0.8h, v11.8h\n"
+ "ld1 {v21.8b}, [x13]\n"
+ "smlal v24.4s, v1.4h, v10.4h\n"
+ "ld1 {v22.8b}, [x15], %[input_depth]\n"
+ "smlal2 v25.4s, v1.8h, v10.8h\n"
+ "ld1 {v23.8b}, [x15]\n"
+ "smlal v24.4s, v2.4h, v11.4h\n"
+ "smlal2 v25.4s, v2.8h, v11.8h\n"
+ "smlal v24.4s, v3.4h, v12.4h\n"
+ "smlal2 v25.4s, v3.8h, v12.8h\n"
+ "smlal v26.4s, v3.4h, v14.4h\n"
+ "smlal2 v27.4s, v3.8h, v14.8h\n"
+ "smlal v24.4s, v4.4h, v13.4h\n"
+ "smlal2 v25.4s, v4.8h, v13.8h\n"
+ "smlal v24.4s, v5.4h, v14.4h\n"
+ "smlal2 v25.4s, v5.8h, v14.8h\n"
+ "smlal v24.4s, v6.4h, v15.4h\n"
+ "smlal2 v25.4s, v6.8h, v15.8h\n"
+ "smlal v26.4s, v6.4h, v17.4h\n"
+ "smlal2 v27.4s, v6.8h, v17.8h\n"
+ "smlal v24.4s, v7.4h, v16.4h\n"
+ "smlal2 v25.4s, v7.8h, v16.8h\n"
+ "smlal v24.4s, v8.4h, v17.4h\n"
+ "uaddw v18.8h, v28.8h, v18.8b\n"
+ "smlal2 v25.4s, v8.8h, v17.8h\n"
+ "uaddw v19.8h, v28.8h, v19.8b\n"
+
+ "smlal v26.4s, v1.4h, v18.4h\n"
+ "uaddw v20.8h, v28.8h, v20.8b\n"
+ "smlal2 v27.4s, v1.8h, v18.8h\n"
+ "smlal v26.4s, v2.4h, v19.4h\n"
+ "uaddw v21.8h, v28.8h, v21.8b\n"
+ "smlal2 v27.4s, v2.8h, v19.8h\n"
+ "smlal v26.4s, v4.4h, v20.4h\n"
+ "smlal v26.4s, v5.4h, v21.4h\n"
+ "smlal2 v27.4s, v4.8h, v20.8h\n"
+ "uaddw v22.8h, v28.8h, v22.8b\n"
+ "smlal2 v27.4s, v5.8h, v21.8h\n"
+ "uaddw v23.8h, v28.8h, v23.8b\n"
+ "smlal v26.4s, v7.4h, v22.4h\n"
+ "smlal2 v27.4s, v7.8h, v22.8h\n"
+ "smlal v26.4s, v8.4h, v23.4h\n"
+ "smlal2 v27.4s, v8.8h, v23.8h\n"
+
+ "dup v28.4s, w1\n"
+ "dup v29.4s, w9\n"
+ "sqrdmulh v24.4s, v24.4s, v28.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v28.4s\n"
+ "sqrdmulh v26.4s, v26.4s, v28.4s\n"
+ "sqrdmulh v27.4s, v27.4s, v28.4s\n"
+ "dup v28.4s, w2\n"
+ "and v30.16b, v24.16b, v29.16b\n"
+ "and v31.16b, v25.16b, v29.16b\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v24.4s, v24.4s, v30.4s\n"
+ "sqadd v25.4s, v25.4s, v31.4s\n"
+ "and v30.16b, v26.16b, v29.16b\n"
+ "and v31.16b, v27.16b, v29.16b\n"
+ "sshr v30.4s, v30.4s, #31\n"
+ "sshr v31.4s, v31.4s, #31\n"
+ "sqadd v26.4s, v26.4s, v30.4s\n"
+ "dup v30.4s, w3\n"
+ "sqadd v27.4s, v27.4s, v31.4s\n"
+ "dup v31.4s, w4\n"
+ "srshl v24.4s, v24.4s, v29.4s\n"
+ "srshl v25.4s, v25.4s, v29.4s\n"
+ "srshl v26.4s, v26.4s, v29.4s\n"
+ "srshl v27.4s, v27.4s, v29.4s\n"
+ "add v24.4s, v24.4s, v28.4s\n"
+ "add v25.4s, v25.4s, v28.4s\n"
+ "add v26.4s, v26.4s, v28.4s\n"
+ "add v27.4s, v27.4s, v28.4s\n"
+ "dup v28.8h, w0\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smax v26.4s, v26.4s, v30.4s\n"
+ "smax v27.4s, v27.4s, v30.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "smin v26.4s, v26.4s, v31.4s\n"
+ "smin v27.4s, v27.4s, v31.4s\n"
+ "sqxtn v24.4h, v24.4s\n"
+ "sqxtn v26.4h, v26.4s\n"
+ "sqxtn2 v24.8h, v25.4s\n"
+ "sqxtn2 v26.8h, v27.4s\n"
+ "sqxtun v24.8b, v24.8h\n"
+ "sqxtun v26.8b, v26.8h\n"
+ "st1 {v24.8b}, [x6], x5\n"
+ "st1 {v26.8b}, [x6]\n"
+ "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n"
+
+ // Handle bottom right output if exists.
+ DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n"
+ "dup v26.4s, w9\n"
+ "dup v27.4s, w1\n"
+ "dup v29.4s, w2\n"
+
+ "smlal v24.4s, v0.4h, v9.4h\n"
+ "smlal2 v25.4s, v0.8h, v9.8h\n"
+ "smlal v24.4s, v1.4h, v10.4h\n"
+ "smlal2 v25.4s, v1.8h, v10.8h\n"
+ "smlal v24.4s, v2.4h, v11.4h\n"
+ "smlal2 v25.4s, v2.8h, v11.8h\n"
+ "smlal v24.4s, v3.4h, v12.4h\n"
+ "smlal2 v25.4s, v3.8h, v12.8h\n"
+ "smlal v24.4s, v4.4h, v13.4h\n"
+ "smlal2 v25.4s, v4.8h, v13.8h\n"
+ "smlal v24.4s, v5.4h, v14.4h\n"
+ "smlal2 v25.4s, v5.8h, v14.8h\n"
+ "smlal v24.4s, v6.4h, v15.4h\n"
+ "smlal2 v25.4s, v6.8h, v15.8h\n"
+ "smlal v24.4s, v7.4h, v16.4h\n"
+ "smlal2 v25.4s, v7.8h, v16.8h\n"
+ "smlal v24.4s, v8.4h, v17.4h\n"
+ "smlal2 v25.4s, v8.8h, v17.8h\n"
+
+ "sqrdmulh v24.4s, v24.4s, v27.4s\n"
+ "sqrdmulh v25.4s, v25.4s, v27.4s\n"
+ "and v18.16b, v24.16b, v26.16b\n"
+ "and v19.16b, v25.16b, v26.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v24.4s, v24.4s, v18.4s\n"
+ "sqadd v25.4s, v25.4s, v19.4s\n"
+ "srshl v24.4s, v24.4s, v26.4s\n"
+ "srshl v25.4s, v25.4s, v26.4s\n"
+ "add v24.4s, v24.4s, v29.4s\n"
+ "add v25.4s, v25.4s, v29.4s\n"
+ "smax v24.4s, v24.4s, v30.4s\n"
+ "smax v25.4s, v25.4s, v30.4s\n"
+ "smin v24.4s, v24.4s, v31.4s\n"
+ "smin v25.4s, v25.4s, v31.4s\n"
+ "sqxtn v24.4h, v24.4s\n"
+ "sqxtn2 v24.8h, v25.4s\n"
+ "sqxtun v24.8b, v24.8h\n"
+ "st1 {v24.8b}, [x6]\n"
+
+ DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr),
+ [output_window_height] "+r"(output_window_height)
+ :
+ // Inputs.
+ [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size),
+ [input_depth] "r"(input_depth),
+ [output_window_width] "r"(output_window_width),
+ [input_width_increment] "r"(input_width_increment),
+ [input_height_increment] "r"(input_height_increment),
+ [output_height_increment] "r"(output_height_increment),
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31",
+ // We use these general-purpose registers.
+ "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
+ "x9", "x10", "x11", "x12", "x13", "x14", "x15",
+ "x19", "x20");
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_1
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER
+#undef DEPTHWISECONV_LABEL_HEIGHT_1_END
}
};
-template <>
-struct ConvKernel3x3FilterDepth8<1, 4, 2, 2> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
- uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7,
- temp_8;
-
+#undef OFFSET_INPUT_DEPTH
+#undef OFFSET_INPUT_ROW_SIZE
+#undef OFFSET_OUTPUT_DEPTH
+#undef OFFSET_OUTPUT_ROW_SIZE
+#undef OFFSET_INPUT_OFFSET
+#undef OFFSET_OUTPUT_OFFSET
+#undef OFFSET_FILTER_OFFSET
+#undef OFFSET_OUTPUT_MULTIPLIER
+#undef OFFSET_OUTPUT_ACTIVATION_MIN
+#undef OFFSET_OUTPUT_ACTIVATION_MAX
+#undef OFFSET_OUTPUT_SHIFT
+#undef OFFSET_INPUT_WIDTH
+#undef OFFSET_INPUT_HEIGHT
+#undef OFFSET_OUTPUT_WIDTH
+#undef OFFSET_OUTPUT_HEIGHT
+#undef STR
+#undef STR_UNEXPANDED
+
+// Copies a subset of the input designated by |input_ptr| into |output_ptr|
+// with the specified output dimensions. Supports output depths of 64 only as
+// this is the cache line size.
+inline void ShuffleInput(const uint8* input_ptr, int64_t input_depth,
+ int32 input_width, int32 input_height,
+ int64_t output_depth, int32 output_width,
+ int32 output_height, uint8* output_ptr) {
+ const int64_t input_row_size = input_depth * input_width;
+ for (int32 y = 0; y < output_height; y++) {
const uint8* ptr = input_ptr;
-
- // Load all inputs for top output.
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- temp_2 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- temp_5 = vld1_u8(ptr + 2 * input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
- temp_8 = vld1_u8(ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Second output.
- output_ptr += output_depth;
-
- ptr = input_ptr + 3 * input_depth;
- temp_0 = vld1_u8(ptr);
- temp_1 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_3 = vld1_u8(ptr);
- temp_4 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_6 = vld1_u8(ptr);
- temp_7 = vld1_u8(ptr + input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
-
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
-
- DotProductAndStore(
- filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8,
- input_6, input_7, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Third output.
- output_ptr += output_depth;
-
- ptr = input_ptr + 5 * input_depth;
- temp_2 = vld1_u8(ptr);
- temp_0 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_5 = vld1_u8(ptr);
- temp_3 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_8 = vld1_u8(ptr);
- temp_6 = vld1_u8(ptr + input_depth);
-
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
-
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
-
- DotProductAndStore(
- filter, input_1, input_2, input_0, input_4, input_5, input_3, input_7,
- input_8, input_6, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
-
- // Fourth output.
- output_ptr += output_depth;
-
- ptr = input_ptr + 7 * input_depth;
- temp_1 = vld1_u8(ptr);
- temp_2 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_4 = vld1_u8(ptr);
- temp_5 = vld1_u8(ptr + input_depth);
- ptr += input_row_size;
- temp_7 = vld1_u8(ptr);
- temp_8 = vld1_u8(ptr + input_depth);
-
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
- }
-};
-
-template <int kFixedStrideWidth, int kFixedStrideHeight>
-struct ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight> {
- static inline void Run(const uint8* input_ptr, int input_depth,
- int32 input_offset, int input_row_size,
- const uint8* filter_ptr, int32 filter_offset,
- const int32* bias_ptr, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_ptr,
- int output_depth, int output_width) {
- Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth);
-
- int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8;
-
- uint8x8_t temp_0 = vld1_u8(input_ptr);
- uint8x8_t temp_1 = vld1_u8(input_ptr + input_depth);
- uint8x8_t temp_2 = vld1_u8(input_ptr + 2 * input_depth);
-
- input_ptr += input_row_size;
- uint8x8_t temp_3 = vld1_u8(input_ptr);
- uint8x8_t temp_4 = vld1_u8(input_ptr + input_depth);
- uint8x8_t temp_5 = vld1_u8(input_ptr + 2 * input_depth);
-
- input_ptr += input_row_size;
- uint8x8_t temp_6 = vld1_u8(input_ptr);
- uint8x8_t temp_7 = vld1_u8(input_ptr + input_depth);
- uint8x8_t temp_8 = vld1_u8(input_ptr + 2 * input_depth);
-
- input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0));
- input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1));
- input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2));
- input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3));
- input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4));
- input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5));
- input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6));
- input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7));
- input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8));
-
- const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
- input_0 = vaddq_s16(input_0, input_offset_vec);
- input_1 = vaddq_s16(input_1, input_offset_vec);
- input_2 = vaddq_s16(input_2, input_offset_vec);
- input_3 = vaddq_s16(input_3, input_offset_vec);
- input_4 = vaddq_s16(input_4, input_offset_vec);
- input_5 = vaddq_s16(input_5, input_offset_vec);
- input_6 = vaddq_s16(input_6, input_offset_vec);
- input_7 = vaddq_s16(input_7, input_offset_vec);
- input_8 = vaddq_s16(input_8, input_offset_vec);
-
- DotProductAndStore(
- filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
- input_7, input_8, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_ptr);
- }
-};
-
-inline void ShuffleInput(const uint8* input_ptr, int input_depth,
- int input_width, int input_height, int output_depth,
- int output_width, int output_height,
- uint8* output_ptr) {
- const int input_row_size = input_depth * input_width;
-
- for (int y = 0; y < output_height; y++) {
- const uint8* ptr = input_ptr;
- for (int x = 0; x < output_width; x++) {
+ for (int32 x = 0; x < output_width; x++) {
memcpy(output_ptr, ptr, output_depth);
output_ptr += output_depth;
ptr += input_depth;
@@ -3873,561 +2216,162 @@ inline void ShuffleInput(const uint8* input_ptr, int input_depth,
}
}
-template <int kFixedHeight, int kFixedStrideWidth, int kFixedStrideHeight>
-struct ConvRow3x3FilterDepth8 {};
-
-template <int kFixedStrideWidth, int kFixedStrideHeight>
-struct ConvRow3x3FilterDepth8<1, kFixedStrideWidth, kFixedStrideHeight> {
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- int out_x = start_x;
-
- // 1x4 at a time.
- for (; out_x <= output_width - 4; out_x += 4) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<1, 4, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 4 * kFixedStrideWidth * input_depth;
- output_data += 4 * output_depth;
- }
-
- // 1x1 at a time.
- for (; out_x < output_width; out_x++) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
+// Calculates the input size depending on stride and output.
+inline int32 get_shuffle_input_size(int32 stride, int32 output) {
+ return stride * (output - 1) + 3;
+}
- input_data += kFixedStrideWidth * input_depth;
- output_data += output_depth;
- }
+// Indicates the input and output dimensions used when shuffling input
+// activations.
+struct ShuffleParams {
+ int32 output_width;
+ int32 output_height;
+ int32 input_width;
+ int32 input_height;
+
+ ShuffleParams() = default;
+ ShuffleParams(int32 output_width, int32 output_height, int32 stride_width,
+ int32 stride_height)
+ : output_width(output_width)
+ , output_height(output_height)
+ , input_width(get_shuffle_input_size(stride_width, output_width))
+ , input_height(get_shuffle_input_size(stride_height, output_height)) {
}
};
-template <int kFixedStrideWidth, int kFixedStrideHeight>
-struct ConvRow3x3FilterDepth8<2, kFixedStrideWidth, kFixedStrideHeight> {
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- int out_x = start_x;
-
- // 2x4 at a time.
- for (; out_x <= output_width - 4; out_x += 4) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<2, 4, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 4 * kFixedStrideWidth * input_depth;
- output_data += 4 * output_depth;
- }
-
- // 2x2 at a time.
- for (; out_x <= output_width - 2; out_x += 2) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<2, 2, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 2 * kFixedStrideWidth * input_depth;
- output_data += 2 * output_depth;
- }
-
- // 2x1 at a time.
- for (; out_x < output_width; out_x++) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<2, 1, kFixedStrideWidth, kFixedStrideHeight>::
- Run(input_ptr, input_depth, input_offset, input_row_size,
- filter_ptr, filter_offset, bias_ptr, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += kFixedStrideWidth * input_depth;
- output_data += output_depth;
+template <int32 kStrideWidth, int32 kStrideHeight>
+struct DepthwiseConvThroughDepth {
+ // Runs the DepthwiseConvWindow kernels through the depth dimension from
+ // |start_depth| to |end_depth|. Keep this not inlined to maintain a small
+ // binary size. We use a DepthwiseConvParams struct for read only params
+ // to minimize call overhead.
+ static __attribute__((noinline)) void Run(const uint8* input_ptr,
+ const uint8* filter_ptr, const int32* bias_ptr, uint8* output_ptr,
+ int64_t start_depth, int64_t end_depth, int64_t input_depth,
+ int64_t input_row_size, int32 output_window_height,
+ int32 output_window_width, const DepthwiseConvParams& params) {
+ for (; start_depth <= end_depth - 8; start_depth += 8) {
+ DepthwiseConvWindow<8, kStrideWidth, kStrideHeight>::Run(
+ input_ptr, filter_ptr, bias_ptr, output_ptr, input_depth,
+ input_row_size, output_window_height, output_window_width, &params);
+ input_ptr += 8;
+ output_ptr += 8;
+ filter_ptr += 8;
+ bias_ptr += 8;
}
}
};
-template <>
-struct ConvRow3x3FilterDepth8<4, 1, 1> {
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- int out_x = start_x;
-
- // 4x4 at a time.
- for (; out_x <= output_width - 4; out_x += 4) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 4, 1, 1>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 4 * input_depth;
- output_data += 4 * output_depth;
- }
-
- // Handle the rest of the right side.
- // 4x2 at a time.
- for (; out_x <= output_width - 2; out_x += 2) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 2, 1, 1>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 2 * input_depth;
- output_data += 2 * output_depth;
- }
-
- // 4x1 at a time.
- for (; out_x < output_width; out_x++) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 1, 1, 1>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += input_depth;
- output_data += output_depth;
- }
- }
-};
+template <int32 kStrideWidth, int32 kStrideHeight>
+struct DepthwiseConvMultiRow {
+ using ConvKernel = DepthwiseConvThroughDepth<kStrideWidth, kStrideHeight>;
-template <>
-struct ConvRow3x3FilterDepth8<4, 2, 2> {
- // The buffer size of the shuffled input.
- static inline constexpr int ShuffleWorkspaceSize() { return 64 * 9 * 9; }
-
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
+ static inline void Run(const uint8* input_data, int32 start_x, int32 start_y,
+ const uint8* filter_data, const int32* bias_data,
+ uint8* output_data, const DepthwiseConvParams& params,
+ const ShuffleParams& shuffle_params,
uint8* shuffle_workspace) {
- // Branch and cache misses increase substantially with stride 2 kernels.
- // Adding prefetching reduces latency by as much as 2x.
- const int i0 = 0;
- const int i1 = input_depth;
- const int i2 = 2 * input_depth;
- const int i3 = 3 * input_depth;
- const int i4 = 4 * input_depth;
- const int i5 = 5 * input_depth;
- const int i6 = 6 * input_depth;
- const int i7 = 7 * input_depth;
- const int i8 = 8 * input_depth;
-
-#define DEPTHWISECONV_PRELOAD_ROW(input_ptr, i) \
- preload_l1_keep(input_ptr + i * input_row_size + i0); \
- preload_l1_keep(input_ptr + i * input_row_size + i1); \
- preload_l1_keep(input_ptr + i * input_row_size + i2); \
- preload_l1_keep(input_ptr + i * input_row_size + i3); \
- preload_l1_keep(input_ptr + i * input_row_size + i4); \
- preload_l1_keep(input_ptr + i * input_row_size + i5); \
- preload_l1_keep(input_ptr + i * input_row_size + i6); \
- preload_l1_keep(input_ptr + i * input_row_size + i7); \
- preload_l1_keep(input_ptr + i * input_row_size + i8);
-
- int out_x = start_x;
- // 4x4 at a time.
- for (; out_x <= output_width - 4; out_x += 4) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- int depth = 0;
- for (; depth <= output_depth - 64; depth += 64) {
- // Preload 9x9 input.
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8);
-
- // For a large input window (64x9x9) that is small enough to fit in L1
- // cache, copy the input into a separate buffer and run the kernel on
- // this new buffer. This reduces the likelihood of cache misses when
- // the kernel is loading input data. If this size is ever changed,
- // update the ShuffleWorkspaceSize() function to return the new size.
- ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 9,
- 9, shuffle_workspace);
- const uint8* shuffled_ptr = &shuffle_workspace[0];
-
- for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) {
- ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run(
- shuffled_ptr, 64, input_offset, 64 * 9, filter_ptr, filter_offset,
- bias_ptr, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_ptr,
- output_depth, output_width);
-
- shuffled_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
+ TFLITE_DCHECK(shuffle_params.input_height ==
+ get_shuffle_input_size(kStrideHeight, shuffle_params.output_height));
+ TFLITE_DCHECK(shuffle_params.input_width ==
+ get_shuffle_input_size(kStrideWidth, shuffle_params.output_width));
+ TFLITE_DCHECK(64 * shuffle_params.input_width * shuffle_params.input_height
+ <= DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE);
+
+ int32 out_x = start_x;
+
+ // Run shuffling on inputs with sufficiently large depth and width. When
+ // these parameters are large enough, more time is taken to load inputs
+ // from memory. At this point, it becomes useful to prefetch and
+ // preshuffle the input data to maximize locality.
+ if (params.output_depth > 64 ||
+ (params.output_depth <= 64 && params.input_width > 150)) {
+ for (; out_x <= (params.output_width - shuffle_params.output_width);
+ out_x += shuffle_params.output_width) {
+ const uint8* input_ptr = input_data;
+ const int32* bias_ptr = bias_data;
+ const uint8* filter_ptr = filter_data;
+ uint8* output_ptr = output_data;
+ int64_t depth = 0;
+ const int64_t shuffle_row_size = 64 * shuffle_params.input_width;
+
+ for (; depth <= params.output_depth - 64; depth += 64) {
+ // Preload.
+ const uint8* h_ptr = input_ptr;
+ for (int32 i = 0; i < shuffle_params.input_height; i++) {
+ const uint8* ptr = h_ptr;
+ for (int32 j = 0; j < shuffle_params.input_width; j++) {
+ asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
+ ptr += params.input_depth;
+ }
+ h_ptr += params.input_row_size;
+ }
+
+ // For a large enough input, shuffle into buckets.
+ ShuffleInput(input_ptr, params.input_depth, params.input_width,
+ params.input_height, 64, shuffle_params.input_width,
+ shuffle_params.input_height, shuffle_workspace);
+ ConvKernel::Run(shuffle_workspace, filter_ptr, bias_ptr, output_ptr,
+ 0, 64, 64, shuffle_row_size,
+ shuffle_params.output_height,
+ shuffle_params.output_width, params);
+ input_ptr += 64;
+ output_ptr += 64;
+ filter_ptr += 64;
+ bias_ptr += 64;
}
- input_ptr += 64;
- }
- // Preload 9x9 input one more time for the rest of the depth.
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7);
- DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8);
-
- for (; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
-
- input_data += 4 * 2 * input_depth;
- output_data += 4 * output_depth;
- }
-
-#undef DEPTHWISECONV_PRELOAD_ROW
-
- // Handle the rest of the right side.
- // 4x2 at a time.
- for (; out_x <= output_width - 2; out_x += 2) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
+ // Preload.
+ const uint8* h_ptr = input_ptr;
+ for (int32 i = 0; i < shuffle_params.input_height; i++) {
+ const uint8* ptr = h_ptr;
+ for (int32 j = 0; j < shuffle_params.input_width; j++) {
+ asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
+ ptr += params.input_depth;
+ }
+ h_ptr += params.input_row_size;
+ }
- input_data += 2 * 2 * input_depth;
- output_data += 2 * output_depth;
- }
+ // Handle leftover depth.
+ ConvKernel::Run(input_ptr, filter_ptr, bias_ptr, output_ptr,
+ depth, params.output_depth, params.input_depth,
+ params.input_row_size, shuffle_params.output_height,
+ shuffle_params.output_width, params);
- // 4x1 at a time.
- for (; out_x < output_width; out_x++) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- for (int depth = 0; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<4, 1, 2, 2>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
+ input_data +=
+ shuffle_params.output_width * kStrideWidth * params.input_depth;
+ output_data += shuffle_params.output_width * params.output_depth;
}
-
- input_data += 2 * input_depth;
- output_data += output_depth;
}
- }
-};
-
-template <>
-struct ConvRow3x3FilterDepth8<8, 2, 2> {
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- // Reuse 4 row kernels twice.
- ConvRow3x3FilterDepth8<4, 2, 2>::Run(
- input_data, start_x, start_y, input_depth, input_width, input_height,
- input_row_size, input_offset, filter_data, filter_offset, bias_data,
- output_offset, output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_depth, output_width,
- shuffle_workspace);
-
- ConvRow3x3FilterDepth8<4, 2, 2>::Run(
- input_data + 2 * 4 * input_row_size, start_x, start_y + 4, input_depth,
- input_width, input_height, input_row_size, input_offset, filter_data,
- filter_offset, bias_data, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data + 4 * output_depth * output_width, output_depth,
- output_width, shuffle_workspace);
- }
-};
-
-template <>
-struct ConvRow3x3FilterDepth8<8, 1, 1> {
- // The buffer size of the shuffled input.
- static inline constexpr int ShuffleWorkspaceSize() { return 64 * 10 * 10; }
-
- static inline void Run(const uint8* input_data, int start_x, int start_y,
- int input_depth, int input_width, int input_height,
- int input_row_size, int32 input_offset,
- const uint8* filter_data, int32 filter_offset,
- const int32* bias_data, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- int output_depth, int output_width,
- uint8* shuffle_workspace) {
- int out_x = start_x;
- // 8x8 at a time.
- for (; out_x <= output_width - 8; out_x += 8) {
- const int32* bias_ptr = bias_data;
- const uint8* filter_ptr = filter_data;
-
- const uint8* input_ptr = input_data;
- uint8* output_ptr = output_data;
-
- int depth = 0;
- for (; depth <= output_depth - 64; depth += 64) {
- // For a large input window (64x10x10) that is small enough to fit in L1
- // cache, copy the input into a separate buffer and run the kernel on
- // this new buffer. This reduces the likelihood of cache misses when
- // the kernel is loading input data. If the size of the input window
- // changes, update the function ShuffleWorkspaceSize() with the new
- // size.
- ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 10,
- 10, shuffle_workspace);
- const uint8* shuffled_ptr = shuffle_workspace;
-
- for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) {
- ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run(
- shuffled_ptr, 64, input_offset, 64 * 10, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- shuffled_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
- input_ptr += 64;
- }
-
- for (; depth <= output_depth - 8; depth += 8) {
- ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run(
- input_ptr, input_depth, input_offset, input_row_size, filter_ptr,
- filter_offset, bias_ptr, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_ptr, output_depth, output_width);
-
- input_ptr += 8;
- output_ptr += 8;
- filter_ptr += 8;
- bias_ptr += 8;
- }
- input_data += 8 * input_depth;
- output_data += 8 * output_depth;
+ const int32 output_leftover_width = params.output_width - out_x;
+ if (output_leftover_width > 0) {
+ ConvKernel::Run(input_data, filter_data, bias_data, output_data, 0,
+ params.output_depth, params.input_depth,
+ params.input_row_size, shuffle_params.output_height,
+ output_leftover_width, params);
}
-
- // Handle the rest of the right side by re-using 4 row kernels twice.
- ConvRow3x3FilterDepth8<4, 1, 1>::Run(
- input_data, out_x, start_y, input_depth, input_width, input_height,
- input_row_size, input_offset, filter_data, filter_offset, bias_data,
- output_offset, output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_depth, output_width,
- shuffle_workspace);
-
- ConvRow3x3FilterDepth8<4, 1, 1>::Run(
- input_data + 4 * input_row_size, out_x, start_y + 4, input_depth,
- input_width, input_height, input_row_size, input_offset, filter_data,
- filter_offset, bias_data, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max,
- output_data + 4 * output_depth * output_width, output_depth,
- output_width, shuffle_workspace);
}
};
-inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims,
- const Dims<4>& filter_dims,
- int stride_width, int stride_height,
- int pad_width, int pad_height,
- int depth_multiplier,
- const Dims<4>& output_dims) {
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
-
- bool supported = filter_width == 3 && filter_height == 3 &&
- depth_multiplier == 1 &&
- (stride_width == 1 || stride_width == 2) &&
- (stride_height == 1 || stride_height == 2) &&
- (stride_width == stride_height) && pad_width == 0 &&
- pad_height == 0 && (input_depth % 8) == 0;
+inline bool Fast3x3FilterKernelSupported(
+ const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width,
+ int32 stride_height, int32 pad_width, int32 pad_height,
+ int32 depth_multiplier, const Dims<4>& output_dims, int32 output_shift) {
+ const int32 input_height = ArraySize(input_dims, 2);
+ const int32 input_width = ArraySize(input_dims, 1);
+ const int32 input_depth = ArraySize(input_dims, 0);
+ const int32 filter_height = ArraySize(filter_dims, 2);
+ const int32 filter_width = ArraySize(filter_dims, 1);
+ const int32 output_height = ArraySize(output_dims, 2);
+ const int32 output_width = ArraySize(output_dims, 1);
+
+ bool supported =
+ filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
+ (stride_width == 1 || stride_width == 2) &&
+ (stride_height == 1 || stride_height == 2) &&
+ (stride_width == stride_height) && pad_width == 0 && pad_height == 0 &&
+ (input_depth % 8) == 0 && (output_shift > 0);
if (!supported) {
return false;
@@ -4436,14 +2380,14 @@ inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims,
// Handle case where padding is zero but padding type is not kValid.
// This would require special boundary case handling that is not supported.
- const int out_x = output_width - 1;
- const int out_y = output_height - 1;
+ const int32 out_x = output_width - 1;
+ const int32 out_y = output_height - 1;
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int32 in_x_origin = (out_x * stride_width) - pad_width;
+ const int32 in_y_origin = (out_y * stride_height) - pad_height;
- const int in_x_end = in_x_origin + filter_width;
- const int in_y_end = in_y_origin + filter_height;
+ const int32 in_x_end = in_x_origin + filter_width;
+ const int32 in_y_end = in_y_origin + filter_height;
// Supported only if filter on the right and bottom boundary lies completely
// within the input.
@@ -4453,128 +2397,135 @@ inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims,
inline void DepthwiseConv3x3Filter(
const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
-
- // Algorithm assumes below constraints. It is optimized for depth multiplier
- // of 1, 3x3 filter, no padding and strides 1 and 2.
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+ const int32* bias_data, const Dims<4>& bias_dims, int32 stride_width,
+ int32 stride_height, int32 pad_width, int32 pad_height,
+ int32 depth_multiplier, int32 output_offset, int32 output_multiplier,
+ int32 output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConvParams params;
+ params.input_depth = ArraySize(input_dims, 0);
+ params.input_width = ArraySize(input_dims, 1);
+ params.input_height = ArraySize(input_dims, 2);
+ params.input_row_size = params.input_depth * params.input_width;
+ params.input_offset = input_offset;
+ params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ params.output_width = ArraySize(output_dims, 1);
+ params.output_height = ArraySize(output_dims, 2);
+ params.output_row_size = params.output_depth * params.output_width;
+ params.output_offset = output_offset;
+ params.filter_offset = filter_offset;
+ params.output_multiplier = output_multiplier;
+ params.output_shift = output_shift;
+ params.output_activation_min = output_activation_min;
+ params.output_activation_max = output_activation_max;
+
+ const int32 filter_height = ArraySize(filter_dims, 2);
+ const int32 filter_width = ArraySize(filter_dims, 1);
+
+ // Algorithm assumes below constraints. It is optimized for depth
+ // multiplier of 1, 3x3 filter, no padding and strides 1 and 2.
+ TFLITE_DCHECK(params.output_depth == params.input_depth * depth_multiplier);
TFLITE_DCHECK(depth_multiplier == 1);
TFLITE_DCHECK(filter_height == 3);
TFLITE_DCHECK(filter_width == 3);
- TFLITE_DCHECK(pad_height == 0);
- TFLITE_DCHECK(pad_width == 0);
TFLITE_DCHECK(stride_height == 1 || stride_height == 2);
TFLITE_DCHECK(stride_width == 1 || stride_width == 2);
TFLITE_DCHECK(stride_width == stride_height);
+ TFLITE_DCHECK(pad_height == 0);
+ TFLITE_DCHECK(pad_width == 0);
- const int input_row_size = input_depth * (input_width + 2 * pad_width);
- const int output_row_size = output_depth * output_width;
- const int input_batch_size = input_row_size * (input_height + 2 * pad_height);
- const int output_batch_size = output_depth * output_width * output_height;
-
- using conv_row_func_t = decltype(&ConvRow3x3FilterDepth8<1, 1, 1>::Run);
- conv_row_func_t conv_1_output_row = ConvRow3x3FilterDepth8<1, 1, 1>::Run;
- conv_row_func_t conv_2_output_rows = ConvRow3x3FilterDepth8<2, 1, 1>::Run;
- conv_row_func_t conv_4_output_rows = ConvRow3x3FilterDepth8<4, 1, 1>::Run;
- conv_row_func_t conv_8_output_rows = ConvRow3x3FilterDepth8<8, 1, 1>::Run;
+ const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int64_t input_batch_size = params.input_row_size * params.input_height;
+ const int64_t output_batch_size =
+ params.output_row_size * params.output_height;
+
+ ShuffleParams one_row_shuffle_params, two_row_shuffle_params,
+ four_row_shuffle_params, eight_row_shuffle_params;
+ if (stride_width == 1) {
+ one_row_shuffle_params = ShuffleParams(30, 1, 1, 1);
+ two_row_shuffle_params = ShuffleParams(22, 2, 1, 1);
+ four_row_shuffle_params = ShuffleParams(14, 4, 1, 1);
+ eight_row_shuffle_params = ShuffleParams(8, 8, 1, 1);
+ } else {
+ one_row_shuffle_params = ShuffleParams(14, 1, 2, 2);
+ two_row_shuffle_params = ShuffleParams(8, 2, 2, 2);
+ four_row_shuffle_params = ShuffleParams(4, 4, 2, 2);
+ eight_row_shuffle_params = ShuffleParams(2, 8, 2, 2);
+ }
+ using conv_multirow_func_t = decltype(&DepthwiseConvMultiRow<1, 1>::Run);
+ conv_multirow_func_t conv_multirow_func = DepthwiseConvMultiRow<1, 1>::Run;
if (stride_width == 2) {
- conv_1_output_row = ConvRow3x3FilterDepth8<1, 2, 2>::Run;
- conv_2_output_rows = ConvRow3x3FilterDepth8<2, 2, 2>::Run;
- conv_4_output_rows = ConvRow3x3FilterDepth8<4, 2, 2>::Run;
- conv_8_output_rows = ConvRow3x3FilterDepth8<8, 2, 2>::Run;
+ conv_multirow_func = DepthwiseConvMultiRow<2, 2>::Run;
}
// Allocate maximum memory needed for shuffled input.
// TODO(mariewhite): The size of this workspace is small enough to be
// allocated on the stack. Eventually we will want to move it to the heap
- // and have it allocated outside of this function, like the im2col_array used
- // in gemmlowp.
-#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64
+ // and have it allocated outside of this function, like the im2col_array
+ // used in gemmlowp.
uint8 shuffle_workspace[DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE];
- // Make sure the kernels using this buffer will not run out of bounds.
- static_assert(ConvRow3x3FilterDepth8<8, 1, 1>::ShuffleWorkspaceSize() <=
- DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE,
- "Shuffle workspace size is too small.");
- static_assert(ConvRow3x3FilterDepth8<4, 2, 2>::ShuffleWorkspaceSize() <=
- DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE,
- "Shuffle workspace size is too small.");
-
-#undef DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE
-
- for (int b = 0; b < batches; ++b) {
+ for (int32 b = 0; b < batches; ++b) {
const uint8* input_ptr = input_data + b * input_batch_size;
uint8* output_ptr = output_data + b * output_batch_size;
- int out_y = 0;
+ int32 out_y = 0;
+
+ // Shuffling shapes that maximize width over the shuffle workspace size
+ // perform better since the inputs are closer together, minimizing
+ // shuffling time.
+ //
+ // If the input shape has width large enough for the 2 row kernels,
+ // we prefer to use this. The innermost loop of the kernels handle
+ // 2 height x 2 width so this is the fastest path.
+ //
+ // If the input shape has smaller width but larger height, shuffling is
+ // still useful and can benefit from kernels 4 row and 8 row kernels.
// Handle 8 rows at a time.
- for (; out_y <= output_height - 8; out_y += 8) {
- conv_8_output_rows(input_ptr, 0, out_y, input_depth, input_width,
- input_height, input_row_size, input_offset,
- filter_data, filter_offset, bias_data, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth,
- output_width, shuffle_workspace);
-
- input_ptr += 8 * stride_height * input_row_size;
- output_ptr += 8 * output_row_size;
+ if (params.input_width < four_row_shuffle_params.input_width) {
+ for (; out_y <= params.output_height - 8; out_y += 8) {
+ conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data,
+ output_ptr, params, eight_row_shuffle_params,
+ shuffle_workspace);
+ input_ptr += 8 * stride_height * params.input_row_size;
+ output_ptr += 8 * params.output_row_size;
+ }
}
// Handle 4 rows at a time.
- for (; out_y <= output_height - 4; out_y += 4) {
- conv_4_output_rows(input_ptr, 0, out_y, input_depth, input_width,
- input_height, input_row_size, input_offset,
- filter_data, filter_offset, bias_data, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth,
- output_width, shuffle_workspace);
-
- input_ptr += 4 * stride_height * input_row_size;
- output_ptr += 4 * output_row_size;
+ if (params.input_width < two_row_shuffle_params.input_width) {
+ for (; out_y <= params.output_height - 4; out_y += 4) {
+ conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data,
+ output_ptr, params, four_row_shuffle_params,
+ shuffle_workspace);
+ input_ptr += 4 * stride_height * params.input_row_size;
+ output_ptr += 4 * params.output_row_size;
+ }
}
// Handle 2 rows at a time.
- for (; out_y <= output_height - 2; out_y += 2) {
- conv_2_output_rows(input_ptr, 0, out_y, input_depth, input_width,
- input_height, input_row_size, input_offset,
- filter_data, filter_offset, bias_data, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth,
- output_width, shuffle_workspace);
-
- input_ptr += 2 * stride_height * input_row_size;
- output_ptr += 2 * output_row_size;
+ for (; out_y <= params.output_height - 2; out_y += 2) {
+ conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data,
+ output_ptr, params, two_row_shuffle_params,
+ shuffle_workspace);
+ input_ptr += 2 * stride_height * params.input_row_size;
+ output_ptr += 2 * params.output_row_size;
}
// Handle one row at a time.
- for (; out_y < output_height; out_y++) {
- conv_1_output_row(input_ptr, 0, out_y, input_depth, input_width,
- input_height, input_row_size, input_offset, filter_data,
- filter_offset, bias_data, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_ptr, output_depth,
- output_width, shuffle_workspace);
-
- input_ptr += stride_height * input_row_size;
- output_ptr += output_row_size;
+ for (; out_y < params.output_height; out_y++) {
+ conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data,
+ output_ptr, params, one_row_shuffle_params,
+ shuffle_workspace);
+ input_ptr += stride_height * params.input_row_size;
+ output_ptr += params.output_row_size;
}
}
}
+// clang-format on
#endif // __aarch64__
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 08f7cfa5a5..38ad32c734 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -352,6 +352,30 @@ void NeonSub1Vector(const float* vector, int v_size, float* result) {
}
}
+bool NeonIsZeroVector(const float* vector, int v_size) {
+ // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot
+ // use the main vectorized loop, and we need to process sequentially.
+ // postamble_start shows the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ const float32x4_t zero_x4_float = vmovq_n_f32(0.0f);
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ const float32x4_t i_x4_float = vld1q_f32(vector + v);
+ uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float);
+ if (vgetq_lane_u32(cmp_result, 0) == 0) return false;
+ if (vgetq_lane_u32(cmp_result, 1) == 0) return false;
+ if (vgetq_lane_u32(cmp_result, 2) == 0) return false;
+ if (vgetq_lane_u32(cmp_result, 3) == 0) return false;
+ }
+
+ // Postamble loop
+ for (int v = postamble_start; v < v_size; ++v) {
+ if (vector[v] != 0.0) return false;
+ }
+ return true;
+}
+
void NeonClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
// If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 9e60d0657b..7a5a8fc541 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -100,6 +100,11 @@ void ZeroVector(float* vector, int v_size) {
float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
+// Check if all entries of a vector are zero.
+bool IsZeroVector(const float* vector, int v_size) {
+ return NEON_OR_PORTABLE(IsZeroVector, vector, v_size);
+}
+
void ClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 64ba5e62f6..d48178d608 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -140,6 +140,45 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
return MatrixMap<Scalar>(data, rows, cols);
}
+// This is like the template-parameter version, except that the power-of-two is
+// passed as a function parameter. The template version is to be preferred,
+// since some target hardware optimizations depend on the range of the exponent.
+template <typename IntegerType>
+IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
+ if (exponent == 0) {
+ return x;
+ }
+ using ScalarIntegerType =
+ typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+ const IntegerType min =
+ gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
+ const IntegerType max =
+ gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+ const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
+
+ const std::int32_t threshold =
+ ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
+ const IntegerType positive_mask =
+ gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
+ const IntegerType negative_mask =
+ gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
+
+ IntegerType result = gemmlowp::ShiftLeft(x, exponent);
+ result = gemmlowp::SelectUsingMask(positive_mask, max, result);
+ result = gemmlowp::SelectUsingMask(negative_mask, min, result);
+ return result;
+}
+
+// This is like the template-parameter version, except that the power-of-two is
+// passed as a function parameter. See raw-integer version for further comments.
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits>
+SaturatingRoundingMultiplyByPOTParam(
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
+ return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
+}
+
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
// BROADCASTING.
//
@@ -1979,11 +2018,23 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
const int gemm_input_rows = gemm_input_dims->sizes[0];
- const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0);
+ // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
+ // The root cause has not yet been identified though. Same applies below for
+ // the other calls commented out. This is a partial rollback of cl/196819423.
+ // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0);
+ const int gemm_input_cols = gemm_input_dims->sizes[1] *
+ gemm_input_dims->sizes[2] *
+ gemm_input_dims->sizes[3];
const int filter_rows = filter_dims.sizes[3];
- const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+ // See b/79927784.
+ // const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+ const int filter_cols =
+ filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
const int output_rows = output_dims.sizes[0];
- const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ // See b/79927784.
+ // const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ const int output_cols =
+ output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
TFLITE_DCHECK_EQ(output_rows, filter_rows);
TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
@@ -2353,24 +2404,27 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- TFLITE_DCHECK_EQ(outer_size, 1);
- int32 square_l2_norm = 0;
- for (int i = 0; i < depth; i++) {
- int32 diff = input_data[i] - input_zero_point;
- square_l2_norm += diff * diff;
- }
- int32 inv_l2norm_multiplier;
- int inv_l2norm_shift;
- GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
- &inv_l2norm_shift);
-
- for (int i = 0; i < depth; i++) {
- int32 diff = input_data[i] - input_zero_point;
- int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
- 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
- int32 unclamped_output_val = 128 + rescaled_diff;
- int32 output_val = std::min(255, std::max(0, unclamped_output_val));
- output_data[i] = static_cast<uint8>(output_val);
+ for (int i = 0; i < outer_size; ++i) {
+ int32 square_l2_norm = 0;
+ for (int c = 0; c < depth; c++) {
+ int32 diff = input_data[c] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32 inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
+
+ for (int c = 0; c < depth; c++) {
+ int32 diff = *input_data - input_zero_point;
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 unclamped_output_val = 128 + rescaled_diff;
+ int32 output_val = std::min(255, std::max(0, unclamped_output_val));
+ *output_data = static_cast<uint8>(output_val);
+ ++input_data;
+ ++output_data;
+ }
}
}
@@ -4556,6 +4610,119 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
}
}
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1_impl(
+ gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+ // assert(__builtin_clz(0u) >= std::numeric_limits<uint32>::digits - 1);
+ // assert(__builtin_clz(0u) <= std::numeric_limits<uint32>::digits);
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+ // The reason for accumulating the result with an extra bit of headroom is
+ // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
+ // recip_denom will otherwise introduce an error.
+ static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
+
+ const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1488522236, std::log(2.0));
+ const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
+ const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1518500250, std::sqrt(0.5));
+ const FixedPoint0 one_quarter =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
+
+ const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1057819769,
+ 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
+
+ const FixedPointAccum shifted_quarter =
+ gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
+
+ // Reinterpret the input value as Q0.31, because we will figure out the
+ // required shift "ourselves" instead of using, say, Rescale.
+ FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
+ // z_a_pow_2 = input_integer_bits - z_a_headroom;
+ int z_a_headroom_plus_1 = __builtin_clz(static_cast<uint32>(z_a.raw()));
+ FixedPoint0 r_a_tmp =
+ SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
+ const int32 r_a_raw =
+ SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
+ // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
+ // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
+ // InputIntegerBits - z_b_headroom - 0.25);
+ const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
+ FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+ InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
+ shifted_quarter);
+
+ // z_b is treated like z_a, but premultiplying by sqrt(0.5).
+ FixedPoint0 z_b = z_a * sqrt_half;
+ int z_b_headroom = __builtin_clz(static_cast<uint32>(z_b.raw())) - 1;
+ const int32 r_b_raw =
+ SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
+ const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
+ FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+ InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
+ shifted_quarter);
+
+ const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
+ const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
+ std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
+
+ const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
+ FixedPoint0 q = r - sqrt_sqrt_half;
+ q = q + q;
+
+ const FixedPoint0 common_sq = q * q;
+ const FixedPoint0 num = q * r + q * common_sq * alpha_n;
+ const FixedPoint0 denom_minus_one_0 =
+ p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
+ const FixedPoint0 recip_denom =
+ one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
+
+ const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
+ return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
+ num_scaled * recip_denom);
+}
+
+// Minimum output bits to accommodate log of maximum input range. It actually
+// does not matter if one considers, say, [-64,64] or [-64,64).
+//
+// For example, run this through Octave:
+// [0:127; ...
+// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
+// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
+constexpr int min_log_x_output_bits(int input_bits) {
+ return input_bits > 90
+ ? 7
+ : input_bits > 44
+ ? 6
+ : input_bits > 21
+ ? 5
+ : input_bits > 10
+ ? 4
+ : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
+}
+
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1(
+ gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+ static_assert(
+ OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
+ "Output integer bits must be sufficent to accommodate logs of inputs.");
+ return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
+ InputIntegerBits>(
+ input_val);
+}
+
// Currently just a copy of the reference code.
inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
int32 input_multiplier, int32 input_left_shift,
@@ -4601,13 +4768,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
- // TODO(b/77858996): Implement fixed-point log().
- // Not a fully-quantized implementation: floating-point log().
- const float float_log_sum_of_exps =
- std::log(static_cast<float>(sum_of_exps.raw()) /
- (1 << (31 - kAccumulationIntegerBits)));
- const int32 fixed_log_sum_of_exps = static_cast<int32>(TfLiteRound(
- float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits))));
+ const int32 fixed_log_sum_of_exps =
+ log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
+ sum_of_exps)
+ .raw();
// rescaled_diff_min is smallest representable in
// Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index d570dadd86..f14667090f 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -127,6 +127,10 @@ void PortableZeroVector(float* vector, int v_size);
// Limit a float input f between +abs_limit and -abs_limit.
float PortableClip(float f, float abs_limit);
+// Check if all entries of a vector are zero.
+bool PortableIsZeroVector(const float* vector, int v_size);
+bool NeonIsZeroVector(const float* vector, int v_size);
+
// Symmetric quantizer.
void PortableSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index e9b6baeaee..d57739279f 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -76,8 +76,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
if (bias_data) {
acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
}
- acc = MultiplyByQuantizedMultiplierSmallerThanOne(
- acc, output_multiplier, output_shift);
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ -output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 2607adc0c1..f8c6f341f7 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -29,9 +29,18 @@ float PortableClip(float f, float abs_limit) {
return result;
}
+bool PortableIsZeroVector(const float* vector, int v_size) {
+ for (int i = 0; i < v_size; ++i) {
+ if (*vector++ != 0.0f) return false;
+ }
+ return true;
+}
+
void PortableSymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values, float* min,
- float* max, float* scaling_factor) {
+ int8_t* quantized_values,
+ float* __restrict__ min,
+ float* __restrict__ max,
+ float* __restrict__ scaling_factor) {
auto minmax = std::minmax_element(values, values + size);
*min = *minmax.first;
*max = *minmax.second;
@@ -71,13 +80,14 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
void PortableMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
- const int8_t* __restrict__ vectors, const float* scaling_factors,
- int n_batch, float* __restrict__ result, int result_stride) {
+ const int8_t* __restrict__ vectors,
+ const float* __restrict__ scaling_factors, int n_batch,
+ float* __restrict__ result, int result_stride) {
int batch, row, col;
for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
// Get the address of the first row.
- int8_t* row_ptr = (int8_t*)matrix; // NOLINT
+ const int8_t* row_ptr = matrix;
for (row = 0; row < m_rows; ++row, result += result_stride) {
// Initialize the dot product sum for the row to 0.
int32_t dotprod = 0;
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index 1757a9f5e5..d2e1fecd25 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -25,6 +25,8 @@ namespace tensor_utils {
// Limit a float input f between +abs_limit and -abs_limit.
float PortableClip(float f, float abs_limit);
+bool PortableIsZeroVector(const float* vector, int v_size);
+
void PortableSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min,
float* max, float* scaling_factor);
@@ -112,6 +114,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
+bool IsZeroVector(const float* vector, int v_size) {
+ return PortableIsZeroVector(vector, v_size);
+}
+
void SymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min, float* max,
float* scaling_factor) {
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index e70d8e5454..48a96f7db0 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -33,8 +33,139 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
+
+// TODO(b/77858996): Add these to gemmlowp.
+template <typename IntegerType>
+IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) {
+ static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ return a;
+}
+
+template <>
+inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) {
+ std::int64_t a64 = a;
+ std::int64_t b64 = b;
+ std::int64_t sum = a64 + b64;
+ return static_cast<std::int32_t>(std::min(
+ static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
+ std::max(
+ static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
+ sum)));
+}
+
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingAddNonGemmlowp(
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
+ return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingAddNonGemmlowp(a.raw(), b.raw()));
+}
+
+template <typename IntegerType>
+IntegerType SaturatingSub(IntegerType a, IntegerType b) {
+ static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ return a;
+}
+
+template <>
+inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) {
+ std::int32_t a32 = a;
+ std::int32_t b32 = b;
+ std::int32_t diff = a32 - b32;
+ return static_cast<std::int16_t>(std::min(32767, std::max(-32768, diff)));
+}
+
+template <>
+inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) {
+ std::int64_t a64 = a;
+ std::int64_t b64 = b;
+ std::int64_t diff = a64 - b64;
+ return static_cast<std::int32_t>(std::min(
+ static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max()),
+ std::max(
+ static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::min()),
+ diff)));
+}
+
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> a,
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> b) {
+ return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingSub(a.raw(), b.raw()));
+}
+// End section to be moved to gemmlowp.
+
namespace reference_ops {
+inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
+ int32 x, int32 quantized_multiplier, int right_shift) {
+ using gemmlowp::RoundingDivideByPOT;
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+}
+
+inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
+ int32 x, int32 quantized_multiplier, int left_shift) {
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
+ quantized_multiplier);
+}
+
+template <typename T>
+int CountLeadingZeros(T integer_input) {
+ static_assert(std::is_unsigned<T>::value,
+ "Only unsigned integer types handled.");
+ if (integer_input == 0) {
+ return std::numeric_limits<T>::digits;
+ }
+ const T one_in_leading_positive = static_cast<T>(1)
+ << (std::numeric_limits<T>::digits - 1);
+ int leading_zeros = 0;
+ while (integer_input < one_in_leading_positive) {
+ integer_input <<= 1;
+ ++leading_zeros;
+ }
+ return leading_zeros;
+}
+
+template <typename IntegerType>
+IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) {
+ if (exponent == 0) {
+ return x;
+ }
+ using ScalarIntegerType =
+ typename gemmlowp::FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+ const IntegerType min =
+ gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
+ const IntegerType max =
+ gemmlowp::Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+ const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
+
+ const std::int32_t threshold =
+ ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1);
+ const IntegerType positive_mask =
+ gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup<IntegerType>(threshold));
+ const IntegerType negative_mask =
+ gemmlowp::MaskIfLessThan(x, gemmlowp::Dup<IntegerType>(-threshold));
+
+ IntegerType result = gemmlowp::ShiftLeft(x, exponent);
+ result = gemmlowp::SelectUsingMask(positive_mask, max, result);
+ result = gemmlowp::SelectUsingMask(negative_mask, min, result);
+ return result;
+}
+
+// If we want to leave IntegerBits fixed, then multiplication
+// by a power of two has to be saturating/rounding, not exact anymore.
+template <typename tRawType, int tIntegerBits>
+gemmlowp::FixedPoint<tRawType, tIntegerBits>
+SaturatingRoundingMultiplyByPOTParam(
+ gemmlowp::FixedPoint<tRawType, tIntegerBits> a, int exponent) {
+ return gemmlowp::FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
+}
+
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
// BROADCASTING.
//
@@ -895,25 +1026,28 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
const Dims<4>& output_dims) {
const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- TFLITE_DCHECK_EQ(outer_size, 1);
- int32 square_l2_norm = 0;
- for (int i = 0; i < depth; i++) {
- int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
- square_l2_norm += diff * diff;
- }
- int32 inv_l2norm_multiplier;
- int inv_l2norm_shift;
- GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
- &inv_l2norm_shift);
+ for (int i = 0; i < outer_size; ++i) {
+ int32 square_l2_norm = 0;
+ for (int c = 0; c < depth; c++) {
+ int32 diff =
+ input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32 inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
- for (int i = 0; i < depth; i++) {
- int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
- int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
- 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
- int32 unclamped_output_val = 128 + rescaled_diff;
- int32 output_val = std::min(255, std::max(0, unclamped_output_val));
- output_data[Offset(output_dims, i, 0, 0, 0)] =
- static_cast<uint8>(output_val);
+ for (int c = 0; c < depth; c++) {
+ int32 diff =
+ input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point;
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 unclamped_output_val = 128 + rescaled_diff;
+ int32 output_val = std::min(255, std::max(0, unclamped_output_val));
+ output_data[Offset(output_dims, c, i, 0, 0)] =
+ static_cast<uint8>(output_val);
+ }
}
}
@@ -2639,6 +2773,121 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
}
}
+// Although currently the name of this function says that it cannot handle
+// values less than 1, in practice it can handle as low as 1/x_max, where
+// x_max is the largest representable input. In other words, the output range
+// is symmetric.
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1_impl(
+ gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+ // The reason for accumulating the result with an extra bit of headroom is
+ // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled *
+ // recip_denom will otherwise introduce an error.
+ static constexpr int kAccumIntegerBits = OutputIntegerBits + 1;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumIntegerBits>;
+
+ const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1488522236, std::log(2.0));
+ const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5)));
+ const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1518500250, std::sqrt(0.5));
+ const FixedPoint0 one_quarter =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0);
+
+ const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 1057819769,
+ 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0)));
+ const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(
+ FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0)));
+
+ const FixedPointAccum shifted_quarter =
+ gemmlowp::Rescale<kAccumIntegerBits>(one_quarter);
+
+ // Reinterpret the input value as Q0.31, because we will figure out the
+ // required shift "ourselves" instead of using, say, Rescale.
+ FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
+ // z_a_pow_2 = input_integer_bits - z_a_headroom;
+ int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
+ FixedPoint0 r_a_tmp =
+ SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
+ const int32 r_a_raw =
+ SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1);
+ // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25);
+ // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25,
+ // InputIntegerBits - z_b_headroom - 0.25);
+ const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
+ FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+ InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
+ shifted_quarter);
+
+ // z_b is treated like z_a, but premultiplying by sqrt(0.5).
+ FixedPoint0 z_b = z_a * sqrt_half;
+ int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
+ const int32 r_b_raw =
+ SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
+ const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
+ FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
+ InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
+ shifted_quarter);
+
+ const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
+ const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw(
+ std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw()));
+
+ const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half);
+ FixedPoint0 q = r - sqrt_sqrt_half;
+ q = q + q;
+
+ const FixedPoint0 common_sq = q * q;
+ const FixedPoint0 num = q * r + q * common_sq * alpha_n;
+ const FixedPoint0 denom_minus_one_0 =
+ p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q;
+ const FixedPoint0 recip_denom =
+ one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0);
+
+ const FixedPointAccum num_scaled = gemmlowp::Rescale<kAccumIntegerBits>(num);
+ return gemmlowp::Rescale<OutputIntegerBits>(z_pow_2_adj * log_2 +
+ num_scaled * recip_denom);
+}
+
+// Minimum output bits to accommodate log of maximum input range. It actually
+// does not matter if one considers, say, [-64,64] or [-64,64).
+//
+// For example, run this through Octave:
+// [0:127; ...
+// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ...
+// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))]
+constexpr int min_log_x_output_bits(int input_bits) {
+ return input_bits > 90
+ ? 7
+ : input_bits > 44
+ ? 6
+ : input_bits > 21
+ ? 5
+ : input_bits > 10
+ ? 4
+ : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1;
+}
+
+template <int OutputIntegerBits, int InputIntegerBits>
+inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
+log_x_for_x_greater_than_or_equal_to_1(
+ gemmlowp::FixedPoint<int32, InputIntegerBits> input_val) {
+ static_assert(
+ OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits),
+ "Output integer bits must be sufficent to accommodate logs of inputs.");
+ return log_x_for_x_greater_than_or_equal_to_1_impl<OutputIntegerBits,
+ InputIntegerBits>(
+ input_val);
+}
+
inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
int32 input_multiplier, int32 input_left_shift,
int32 reverse_scaling_divisor,
@@ -2681,13 +2930,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
- // TODO(b/77858996): Implement fixed-point log().
- // Not a fully-quantized implementation: floating-point log().
- const float float_log_sum_of_exps =
- std::log(static_cast<float>(sum_of_exps.raw()) /
- (1 << (31 - kAccumulationIntegerBits)));
- const int32 fixed_log_sum_of_exps = static_cast<int32>(TfLiteRound(
- float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits))));
+ const int32 fixed_log_sum_of_exps =
+ log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
+ sum_of_exps)
+ .raw();
// rescaled_diff_min is smallest representable in
// Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc
new file mode 100644
index 0000000000..c1c50dff4d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace {
+void TestOneResizeBilinear(int batch, int depth, int input_width,
+ int input_height, int output_width,
+ int output_height) {
+ Dims<4> input_dims_inference =
+ MakeDimsForInference(depth, input_width, input_height, batch);
+ Dims<4> output_dims_inference =
+ MakeDimsForInference(depth, output_width, output_height, batch);
+
+ const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
+ const int output_buffer_size =
+ RequiredBufferSizeForDims(output_dims_inference);
+
+ std::vector<float> input_data(input_buffer_size, 0);
+ std::vector<float> reference_output_data(output_buffer_size, 0);
+ // Initialize the output data with something other than zero, so we can catch
+ // issue with kernels failing to initialize the output.
+ std::vector<float> output_data(output_buffer_size, 3.1415);
+
+ const float input_amplitude = 1.f;
+ FillRandom(&input_data, -input_amplitude, input_amplitude);
+
+ Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
+ std::vector<int32> output_size_data = {output_height, output_width};
+
+ reference_ops::ResizeBilinear(
+ input_data.data(), input_dims_inference, output_size_data.data(),
+ output_size_dims, reference_output_data.data(), output_dims_inference);
+ optimized_ops::ResizeBilinear(input_data.data(), input_dims_inference,
+ output_size_data.data(), output_size_dims,
+ output_data.data(), output_dims_inference);
+
+ double sum_diff = 0;
+ float max_abs_val = 0;
+ for (int i = 0; i < output_buffer_size; i++) {
+ sum_diff += std::abs(output_data[i] - reference_output_data[i]);
+ max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i]));
+ }
+
+ if (sum_diff != 0.f) {
+ const float mean_diff = static_cast<float>(sum_diff / output_buffer_size);
+ const float relative_error = std::abs(mean_diff) / max_abs_val;
+ ASSERT_LT(relative_error, 1e-5f);
+ }
+}
+
+TEST(ResizeBilinear, TestResizeBilinear) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+
+ TestOneResizeBilinear(batch, depth, input_width, input_height, output_width,
+ output_height);
+ }
+}
+
+TEST(ResizeBilinear2x2, TestResizeBilinear) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_width = input_width * 2;
+ const int output_height = input_height * 2;
+
+ TestOneResizeBilinear(batch, depth, input_width, input_height, output_width,
+ output_height);
+ }
+}
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
new file mode 100644
index 0000000000..d781a7b642
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
@@ -0,0 +1,227 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+
+namespace tflite {
+namespace {
+
+void RunSoftmaxFloatReference(const uint8* input_data,
+ const Dims<4>& dims_common, int32 input_offset,
+ const double input_scale, int stride, float beta,
+ uint8* reference_output_data) {
+ const int ref_buffer_size = RequiredBufferSizeForDims(dims_common);
+ std::vector<float> reference_dequant_data(ref_buffer_size);
+ std::vector<float> reference_output_float_data(ref_buffer_size);
+
+ // Reference data generated via Dequant of input into float, and then applying
+ // float Softmax.
+ reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale,
+ reference_dequant_data.data(), dims_common);
+ optimized_ops::Softmax(reference_dequant_data.data(), dims_common, beta,
+ reference_output_float_data.data(), dims_common);
+ // Work with quantized scaling for Softmax, under which 256 represents 1, but
+ // we limit this to 255.
+ for (int i = 0; i < ref_buffer_size; i++) {
+ reference_output_data[i] = std::min(
+ 255,
+ static_cast<int>(std::round(256.0f * reference_output_float_data[i])));
+ }
+}
+
+void CheckOutputData(const uint8* test_output, const uint8* reference_output,
+ const Dims<4>& dims_common, const string& check_label,
+ bool be_exacting) {
+ const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ // While calculating some metrics in floating point, we work with quantized
+ // scaling.
+ std::vector<int> diff(buffer_size);
+ int64_t sum_diff = 0;
+ int64_t sum_abs_diff = 0;
+ for (int i = 0; i < buffer_size; i++) {
+ diff[i] = static_cast<int>(test_output[i]) - reference_output[i];
+ sum_diff += diff[i];
+ sum_abs_diff += std::abs(diff[i]);
+ }
+ // These stats help understand test failures.
+ std::sort(std::begin(diff), std::end(diff));
+ const int min_diff = diff.front();
+ const int max_diff = diff.back();
+ const int median_diff = diff[diff.size() / 2];
+ const float mean_diff = static_cast<float>(sum_diff) / buffer_size;
+ const float mean_abs_diff = static_cast<float>(sum_abs_diff) / buffer_size;
+ // We either check for bit exactness (against the reference quantized version)
+ // or for general accuracy, allowing off-by-one (against the float reference).
+ if (be_exacting) {
+ ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0);
+ } else {
+ // For small numbers of samples, the estimates of the means vary more.
+ // Rather than widen the tolerances, we skip the smaller tests.
+ ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) ||
+ buffer_size < 10000) &&
+ std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 &&
+ std::abs(max_diff) <= 1);
+ }
+}
+
+// Runs the Softmax and compares against the float reference implementation and
+// the quantized reference implementation.
+void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
+ int32 input_offset, const double input_scale, int stride,
+ float beta) {
+ const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ std::vector<uint8> optimized_softmax_output(buffer_size);
+ std::vector<uint8> reference_float_softmax_output(buffer_size);
+ std::vector<uint8> reference_quant_softmax_output(buffer_size);
+
+ RunSoftmaxFloatReference(input_data, dims_common, input_offset, input_scale,
+ stride, beta, reference_float_softmax_output.data());
+
+ int32 input_beta_multiplier;
+ int input_beta_left_shift;
+ static const int kScaledDiffIntegerBits = 5;
+ tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits,
+ &input_beta_multiplier,
+ &input_beta_left_shift);
+ // diff_min has a negative value, and is used to limit the maximum magnitude
+ // of the diffs, which are <= 0.
+ const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
+ input_beta_left_shift);
+
+ optimized_ops::Softmax(input_data, dims_common, input_beta_multiplier,
+ input_beta_left_shift, diff_min,
+ optimized_softmax_output.data(), dims_common);
+ reference_ops::Softmax(input_data, dims_common, input_beta_multiplier,
+ input_beta_left_shift, diff_min,
+ reference_quant_softmax_output.data(), dims_common);
+
+ CheckOutputData(optimized_softmax_output.data(),
+ reference_float_softmax_output.data(), dims_common,
+ "Optimized vs float reference", false);
+ CheckOutputData(optimized_softmax_output.data(),
+ reference_quant_softmax_output.data(), dims_common,
+ "Optimized vs quant reference", true);
+ CheckOutputData(reference_quant_softmax_output.data(),
+ reference_float_softmax_output.data(), dims_common,
+ "Quant reference vs float reference", false);
+}
+
+// This function picks some random Softmax params, which are checked for
+// desirability. If not acceptable, it returns false. If they're OK,
+// it runs the Softmax test and returns true. This allows the caller
+// to loop until a test has been run.
+//
+// Currently we do not reject for any reason.
+bool TryOneUniformSoftmax() {
+ // We pick mostly positive values, on the whole emphasizing smaller values and
+ // therefore faster tests. We test a wider range of depths. In the case of
+ // Softmax, the width and height really just create test repetitions.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500);
+ const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0));
+ const int32 input_offset = UniformRandomInt(-256, 0);
+ const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10);
+
+ Dims<4> dims_common =
+ MakeDimsForInference(input_depth, input_width, input_height, batch);
+ const int buffer_size = RequiredBufferSizeForDims(dims_common);
+
+ std::vector<uint8> input_data(buffer_size);
+ FillRandom(&input_data);
+ RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale,
+ stride, beta);
+ return true;
+}
+
+// See TryOneUniformSoftmax() for a general description.
+//
+// Tests with "skyscraper" input patterns are included for two reasons. (a)
+// Bimodal distributions are potentially challenging and perhaps more
+// realistic than simple uniform random inputs. (b) Some implementations of
+// Softmax may adapt as they traverse the depth, and so we test handling of
+// cases where relatively small values are encountered at the beginning and end.
+bool TryOneSkyscraperSoftmax(bool small_depth) {
+ // We pick mostly positive values, on the whole emphasizing smaller values and
+ // therefore faster tests. We test a wider range of depths. In the case of
+ // Softmax, the width and height really just create test repetitions.
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int input_depth = small_depth
+ ? ExponentialRandomPositiveInt(0.75f, 40, 500)
+ : ExponentialRandomPositiveInt(0.75f, 175, 500);
+ const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200);
+ const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0));
+ const int32 input_offset = UniformRandomInt(-256, 0);
+ const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10);
+ // Extra parameters for skyscraper input patterns.
+ const double middle_proportion =
+ ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0);
+ const int middle_min = UniformRandomInt(0, 255);
+ const int sides_max = UniformRandomInt(0, middle_min);
+
+ Dims<4> dims_common =
+ MakeDimsForInference(input_depth, input_width, input_height, batch);
+ const int buffer_size = RequiredBufferSizeForDims(dims_common);
+
+ std::vector<uint8> input_data(buffer_size);
+ FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
+ sides_max);
+ RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale,
+ stride, beta);
+ return true;
+}
+
+TEST(TestQuantizedSoftmax, UniformSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneUniformSoftmax()) {
+ }
+ }
+}
+
+TEST(TestQuantizedSoftmax, SkyscraperSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneSkyscraperSoftmax(false)) {
+ }
+ }
+}
+
+TEST(TestQuantizedSoftmax, SmallSkyscraperSoftmaxTests) {
+ const int kTestsToRun = 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ while (!TryOneSkyscraperSoftmax(true)) {
+ }
+ }
+}
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index e1c9ccd84b..5160e22307 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -23,6 +23,9 @@ namespace tensor_utils {
// Limit a float input f between +abs_limit and -abs_limit.
float Clip(float f, float abs_limit);
+// Checks if all entries of vector are zero.
+bool IsZeroVector(const float* vector, int v_size);
+
// Quantizes a buffer of floating point values using a symmetric quantization
// (i.e. linear quantization without an offset) to 8-bit signed integers.
// It also outputs the range (min, max) of the floating point buffer, and the
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 3d8a2eada0..14ee528394 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -32,6 +32,25 @@ TEST(uKernels, ClipTest) {
{0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0})));
}
+TEST(uKernels, IsZeroTest) {
+ constexpr int kVectorSize = 21;
+ static float zeros[kVectorSize] = {0.0};
+ EXPECT_TRUE(IsZeroVector(zeros, kVectorSize));
+
+ static float nonzeros[kVectorSize] = {
+ 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12,
+ 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18, 1e-19,
+ 1e-20, 1e-21, 1e-22, 1e-23, 1e-24, 1e-25, 1e-26};
+ EXPECT_FALSE(IsZeroVector(nonzeros, kVectorSize));
+}
+
+TEST(uKernels, GeneratedIsZeroTest) {
+ constexpr int kVectorSize = 39;
+ std::vector<float> input(kVectorSize);
+ ZeroVector(input.data(), kVectorSize);
+ EXPECT_TRUE(IsZeroVector(input.data(), kVectorSize));
+}
+
TEST(uKernels, SymmetricQuantizeFloatsTest) {
constexpr int kVectorSize = 9;
static float input[kVectorSize] = {-640, -635.0, -630, 10.0, 2.0,
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc
new file mode 100644
index 0000000000..9b1fd9b344
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc
@@ -0,0 +1,121 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
+
+#include <cmath>
+#include <iterator>
+
+namespace tflite {
+
+Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) {
+ Dims<4> result;
+ int cum_prod = 1;
+
+ result.sizes[0] = depth;
+ result.strides[0] = cum_prod;
+ cum_prod *= result.sizes[0];
+
+ result.sizes[1] = width;
+ result.strides[1] = cum_prod;
+ cum_prod *= result.sizes[1];
+
+ result.sizes[2] = height;
+ result.strides[2] = cum_prod;
+ cum_prod *= result.sizes[2];
+
+ result.sizes[3] = batch;
+ result.strides[3] = cum_prod;
+
+ return result;
+}
+
+// this is a copied from an internal function in propagate_fixed_sizes.cc
+bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
+ int filter_height, int stride, PaddingType padding_type,
+ Dims<4>* output_dims, int* pad_width, int* pad_height) {
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_height = ArraySize(input_dims, 2);
+ const int batch = ArraySize(input_dims, 3);
+
+ int output_height = 0;
+ int output_width = 0;
+ if (padding_type == PaddingType::kValid) {
+ output_height = (input_height + stride - filter_height) / stride;
+ output_width = (input_width + stride - filter_width) / stride;
+ } else if (padding_type == PaddingType::kSame) {
+ output_height = (input_height + stride - 1) / stride;
+ output_width = (input_width + stride - 1) / stride;
+ } else {
+ return false;
+ }
+
+ if (output_width <= 0 || output_height <= 0) {
+ return false;
+ }
+
+ *pad_height =
+ ((output_height - 1) * stride + filter_height - input_height) / 2;
+ *pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2;
+ *output_dims =
+ MakeDimsForInference(output_depth, output_width, output_height, batch);
+ return true;
+}
+
+std::mt19937& RandomEngine() {
+ static std::mt19937 engine;
+ return engine;
+}
+
+int UniformRandomInt(int min, int max) {
+ std::uniform_int_distribution<int> dist(min, max);
+ return dist(RandomEngine());
+}
+
+float UniformRandomFloat(float min, float max) {
+ std::uniform_real_distribution<float> dist(min, max);
+ return dist(RandomEngine());
+}
+
+int ExponentialRandomPositiveInt(float percentile, int percentile_val,
+ int max_val) {
+ const float lambda =
+ -std::log(1.f - percentile) / static_cast<float>(percentile_val);
+ std::exponential_distribution<float> dist(lambda);
+ float val;
+ do {
+ val = dist(RandomEngine());
+ } while (!val || !std::isfinite(val) || val > max_val);
+ return static_cast<int>(std::ceil(val));
+}
+
+float ExponentialRandomPositiveFloat(float percentile, float percentile_val,
+ float max_val) {
+ const float lambda =
+ -std::log(1.f - percentile) / static_cast<float>(percentile_val);
+ std::exponential_distribution<float> dist(lambda);
+ float val;
+ do {
+ val = dist(RandomEngine());
+ } while (!std::isfinite(val) || val > max_val);
+ return val;
+}
+
+void FillRandom(std::vector<float>* vec, float min, float max) {
+ std::uniform_real_distribution<float> dist(min, max);
+ auto gen = std::bind(dist, RandomEngine());
+ std::generate(std::begin(*vec), std::end(*vec), gen);
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h
new file mode 100644
index 0000000000..26078cef49
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.h
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_
+
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <vector>
+
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+// Creates a Dims struct from a set of dimensions.
+Dims<4> MakeDimsForInference(int depth, int width, int height, int batch);
+
+// Computes output and padding dimensions.
+bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
+ int filter_height, int stride, PaddingType padding_type,
+ Dims<4>* output_dims, int* pad_width, int* pad_height);
+
+// Returns a mt19937 random engine.
+std::mt19937& RandomEngine();
+
+// Returns a random integer uniformly distributed between |min| and |max|.
+int UniformRandomInt(int min, int max);
+
+// Returns a random float uniformly distributed between |min| and |max|.
+float UniformRandomFloat(float min, float max);
+
+// Returns a random element in |v|.
+template <typename T>
+const T& RandomElement(const std::vector<T>& v) {
+ return v[UniformRandomInt(0, v.size() - 1)];
+}
+
+// Returns a random exponentially distributed integer.
+int ExponentialRandomPositiveInt(float percentile, int percentile_val,
+ int max_val);
+
+// Returns a random exponentially distributed float.
+float ExponentialRandomPositiveFloat(float percentile, float percentile_val,
+ float max_val);
+
+// Fills a vector with random floats between |min| and |max|.
+void FillRandom(std::vector<float>* vec, float min, float max);
+
+// Fills a vector with random numbers between |min| and |max|.
+template <typename T>
+void FillRandom(std::vector<T>* vec, T min, T max) {
+ std::uniform_int_distribution<T> dist(min, max);
+ auto gen = std::bind(dist, RandomEngine());
+ std::generate(std::begin(*vec), std::end(*vec), gen);
+}
+
+// Fills a vector with random numbers.
+template <typename T>
+void FillRandom(std::vector<T>* vec) {
+ FillRandom(vec, std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
+}
+
+template <typename T>
+void FillRandom(typename std::vector<T>::iterator begin_it,
+ typename std::vector<T>::iterator end_it, T min, T max) {
+ std::uniform_int_distribution<T> dist(min, max);
+ auto gen = std::bind(dist, RandomEngine());
+ std::generate(begin_it, end_it, gen);
+}
+
+// Fill with a "skyscraper" pattern, in which there is a central section (across
+// the depth) with higher values than the surround.
+template <typename T>
+void FillRandomSkyscraper(std::vector<T>* vec, int depth,
+ double middle_proportion, uint8 middle_min,
+ uint8 sides_max) {
+ for (auto base_it = std::begin(*vec); base_it != std::end(*vec);
+ base_it += depth) {
+ auto left_it = base_it + std::ceil(0.5 * depth * (1.0 - middle_proportion));
+ auto right_it =
+ base_it + std::ceil(0.5 * depth * (1.0 + middle_proportion));
+ FillRandom(base_it, left_it, std::numeric_limits<T>::min(), sides_max);
+ FillRandom(left_it, right_it, middle_min, std::numeric_limits<T>::max());
+ FillRandom(right_it, base_it + depth, std::numeric_limits<T>::min(),
+ sides_max);
+ }
+}
+
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 43c6883278..d5293edd56 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -20,6 +20,7 @@ limitations under the License.
namespace tflite {
enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
+enum class PaddingType { kNone, kSame, kValid };
// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
index 239b533a17..184028427f 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -37,7 +37,6 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <=
1e-6 * std::min(input_product_scale, bias_scale));
TF_LITE_ENSURE(context, input_product_scale >= 0);
- TF_LITE_ENSURE(context, input_product_scale < output_scale);
*multiplier = input_product_scale / output_scale;
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
index 11cc666bad..070ed60040 100644
--- a/tensorflow/contrib/lite/kernels/l2norm_test.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -67,7 +67,7 @@ class L2NormOpModel : public SingleOpModel {
int output_;
};
-TEST(L2NormOpTest, SimpleTest) {
+TEST(L2NormOpTest, SimpleFloatTest) {
L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
ActivationFunctionType_NONE);
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
@@ -76,7 +76,7 @@ TEST(L2NormOpTest, SimpleTest) {
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
-TEST(L2NormOpTest, MultipleBatchesTest) {
+TEST(L2NormOpTest, MultipleBatchFloatTest) {
L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32,
ActivationFunctionType_NONE);
m.SetInput({
@@ -105,6 +105,32 @@ TEST(L2NormOpTest, SimpleUint8Test) {
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}
+TEST(L2NormOpTest, MultipleBatchUint8Test) {
+ L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
+
+ m.QuantizeAndPopulate<uint8_t>(m.input(),
+ {
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({
+ 58, 166, 173, 205, 83, 134, // batch 1
+ 58, 166, 173, 205, 83, 134, // batch 2
+ 58, 166, 173, 205, 83, 134, // batch 3
+ }));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3
+ },
+ 0.1)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 107c84e666..eed57d412b 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -155,7 +155,6 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
nn_type, static_cast<uint32_t>(tensor->dims->size),
reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
-
// TODO(aselle): Based on Michael's suggestion, limiting this to read
// only memory
if (tensor->allocation_type == kTfLiteMmapRo) {
@@ -168,7 +167,12 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
CHECK_NN(ANeuralNetworksModel_setOperandValue(
nn_model, next_id, tensor->data.raw, tensor->bytes));
}
+ } else if (tensor->bytes == 0) {
+ // These size 0 tensors are optional tensors reserved.
+ CHECK_NN(
+ ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0));
}
+
++next_id;
}
return next_id;
@@ -177,7 +181,9 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
// Adds the operations and their parameters to the NN API model.
// 'next-id' is the operand ID of the next operand of the model.
void AddOpsAndParams(tflite::Interpreter* interpreter,
- ANeuralNetworksModel* nn_model, uint32_t next_id) {
+ ANeuralNetworksModel* nn_model, uint32_t next_id,
+ std::vector<int>* model_state_inputs,
+ std::vector<int>* model_state_outputs) {
for (size_t i = 0; i < interpreter->nodes_size(); i++) {
const auto* node_and_registration = interpreter->node_and_registration(i);
const TfLiteNode& node = node_and_registration->first;
@@ -188,6 +194,8 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
// Add the parameters.
std::vector<uint32_t> augmented_inputs(
node.inputs->data, node.inputs->data + node.inputs->size);
+ std::vector<uint32_t> augmented_outputs(
+ node.outputs->data, node.outputs->data + node.outputs->size);
auto add_scalar_int32 = [&nn_model, &augmented_inputs,
&next_id](int value) {
@@ -207,12 +215,23 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
augmented_inputs.push_back(next_id++);
};
+ // Handle state tensors of RNN, LSTM, SVDF.
+ // For each state_out tensor, a corresponding state_in operand needs to be
+ // created for NNAPI.
auto duplicate_state_tensor_float32 =
- [interpreter, &nn_model, &augmented_inputs](int tensor_id) {
+ [interpreter, &nn_model, &next_id, &augmented_inputs,
+ &model_state_inputs, &model_state_outputs](int tensor_id) {
const TfLiteTensor* tensor = interpreter->tensor(tensor_id);
- CHECK_NN(ANeuralNetworksModel_setOperandValue(
- nn_model, tensor_id, tensor->data.raw, tensor->bytes));
- augmented_inputs.push_back(tensor_id);
+ ANeuralNetworksOperandType operand_type{
+ ANEURALNETWORKS_TENSOR_FLOAT32,
+ static_cast<uint32_t>(tensor->dims->size),
+ reinterpret_cast<uint32_t*>(tensor->dims->data),
+ tensor->params.scale, tensor->params.zero_point};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
+ augmented_inputs.push_back(next_id);
+ model_state_inputs->push_back(next_id);
+ model_state_outputs->push_back(tensor_id);
+ next_id++;
};
auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); };
@@ -275,28 +294,51 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
add_scalar_float32(builtin->proj_clip);
};
+ // LSTM in NNAPI requires scratch tensor as an output operand.
+ auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model,
+ &next_id, &augmented_outputs]() {
+ int scratch_buffer_index = node.temporaries->data[0];
+ const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index);
+ ANeuralNetworksOperandType operand_type{
+ ANEURALNETWORKS_TENSOR_FLOAT32,
+ static_cast<uint32_t>(tensor->dims->size),
+ reinterpret_cast<uint32_t*>(tensor->dims->data), tensor->params.scale,
+ tensor->params.zero_point};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
+ augmented_outputs.insert(augmented_outputs.begin(), next_id++);
+ };
+
auto add_mean_params = [&add_scalar_int32](void* data) {
auto builtin = reinterpret_cast<TfLiteMeanParams*>(data);
add_scalar_int32(builtin->keep_dims);
};
-#if 0
- auto add_reshape_params = [&](void* data) {
- auto builtin = reinterpret_cast<TfLiteReshapeParams*>(data);
- uint32_t tensor_size_shape = builtin->num_dimensions;
- ANeuralNetworksOperandType operand_type{
- ANEURALNETWORKS_TENSOR_INT32,
- {static_cast<uint32_t>(1),
- reinterpret_cast<uint32_t*>(&tensor_size_shape)},
- 0,
- 0};
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
- CHECK_NN(ANeuralNetworksModel_setOperandValue(
- nn_model, next_id, builtin->shape,
- sizeof(int) * builtin->num_dimensions));
- augmented_inputs.push_back(next_id++);
+ auto add_svdf_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteSVDFParams*>(data);
+ add_scalar_int32(builtin->rank);
+ add_scalar_int32(builtin->activation);
};
-#endif
+
+ auto add_rnn_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteRNNParams*>(data);
+ add_scalar_int32(builtin->activation);
+ };
+
+ // Handle optional input tensors.
+ auto add_optional_tensors = [&nn_model, &augmented_inputs,
+ &next_id](int nn_type) {
+ for (size_t idx = 0; idx < augmented_inputs.size(); idx++) {
+ if (augmented_inputs[idx] == kOptionalTensor) {
+ const std::vector<uint32_t> dim = {0, 0};
+ ANeuralNetworksOperandType operand_type{nn_type, 2, dim.data(), 0, 0};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id,
+ nullptr, 0))
+ augmented_inputs[idx] = next_id++;
+ }
+ }
+ };
+
int nnapi_version = 10;
ANeuralNetworksOperationType nn_op_type;
@@ -366,13 +408,31 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
break;
case tflite::BuiltinOperator_LSTM: {
duplicate_state_tensor_float32(
- node.outputs->data[/*kOutputStateTensor*/ 1]);
+ node.outputs->data[/*kOutputStateTensor*/ 0]);
duplicate_state_tensor_float32(
- node.outputs->data[/*kCellStateTensor*/ 2]);
+ node.outputs->data[/*kCellStateTensor*/ 1]);
add_lstm_params(node.builtin_data);
+ add_lstm_scratch_tensor_float32();
+ add_optional_tensors(ANEURALNETWORKS_TENSOR_FLOAT32);
nn_op_type = ANEURALNETWORKS_LSTM;
break;
}
+ case tflite::BuiltinOperator_SVDF: {
+ duplicate_state_tensor_float32(node.outputs->data[/*kStateTensor*/ 0]);
+ add_svdf_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_SVDF;
+ break;
+ }
+ case tflite::BuiltinOperator_RNN: {
+ duplicate_state_tensor_float32(
+ node.outputs->data[/*kHiddenStateTensor*/ 0]);
+ add_rnn_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_RNN;
+ break;
+ }
+ case tflite::BuiltinOperator_EMBEDDING_LOOKUP:
+ nn_op_type = ANEURALNETWORKS_EMBEDDING_LOOKUP;
+ break;
case tflite::BuiltinOperator_PAD:
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_PAD;
@@ -392,12 +452,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
break;
case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
case tflite::BuiltinOperator_LSH_PROJECTION:
- case tflite::BuiltinOperator_SVDF:
case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
- case tflite::BuiltinOperator_RNN:
case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
- case tflite::BuiltinOperator_EMBEDDING_LOOKUP:
case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE:
case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
@@ -450,8 +507,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
// Add the operation.
CHECK_NN(ANeuralNetworksModel_addOperation(
nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),
- augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
- reinterpret_cast<uint32_t*>(node.outputs->data)));
+ augmented_inputs.data(),
+ static_cast<uint32_t>(augmented_outputs.size()),
+ reinterpret_cast<uint32_t*>(augmented_outputs.data())));
}
}
@@ -475,12 +533,25 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
}
uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list);
- AddOpsAndParams(interpreter, nn_model_, next_id);
+ AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
+ &model_states_outputs_);
+
+ std::vector<int> augmented_inputs = interpreter->inputs();
+ std::vector<int> augmented_outputs = interpreter->outputs();
+
+ // All state tensors input/output need to be treated as model input/output.
+ augmented_inputs.insert(augmented_inputs.end(),
+ model_states_inputs_.begin(),
+ model_states_inputs_.end());
+ augmented_outputs.insert(augmented_outputs.end(),
+ model_states_outputs_.begin(),
+ model_states_outputs_.end());
+
CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs(
- nn_model_, static_cast<uint32_t>(interpreter->inputs().size()),
- reinterpret_cast<const uint32_t*>(interpreter->inputs().data()),
- static_cast<uint32_t>(interpreter->outputs().size()),
- reinterpret_cast<const uint32_t*>(interpreter->outputs().data())));
+ nn_model_, static_cast<uint32_t>(augmented_inputs.size()),
+ reinterpret_cast<const uint32_t*>(augmented_inputs.data()),
+ static_cast<uint32_t>(augmented_outputs.size()),
+ reinterpret_cast<const uint32_t*>(augmented_outputs.data())));
CHECK_NN(ANeuralNetworksModel_finish(nn_model_));
}
if (!nn_compiled_model_) {
@@ -507,6 +578,7 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
CHECK_NN(ANeuralNetworksExecution_setInput(
execution, i, nullptr, tensor->data.raw, tensor->bytes));
}
+
// Tell nn api where to place final data.
for (size_t i = 0; i < interpreter->outputs().size(); i++) {
int output = interpreter->outputs()[i];
@@ -514,6 +586,24 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
CHECK_NN(ANeuralNetworksExecution_setOutput(
execution, i, nullptr, tensor->data.raw, tensor->bytes));
}
+
+ // The state_out of previous invocation need to be mapped to state_in of
+ // current invocation.
+ for (size_t i = 0; i < model_states_outputs_.size(); i++) {
+ int state_tensor_idx = model_states_outputs_[i];
+ TfLiteTensor* tensor = interpreter->tensor(state_tensor_idx);
+ // Here we are using a deep copy for state_in tensors so that we are not
+ // reading and writing into the same buffer during a invocation.
+ // TODO(miaowang): using double shared buffer to minimize the copies.
+ CHECK_NN(ANeuralNetworksExecution_setInput(
+ execution, i + interpreter->inputs().size(), nullptr, tensor->data.raw,
+ tensor->bytes));
+ // Tell NNAPI where to output the state_out.
+ CHECK_NN(ANeuralNetworksExecution_setOutput(
+ execution, i + interpreter->outputs().size(), nullptr, tensor->data.raw,
+ tensor->bytes));
+ }
+
// Currently use blocking compute.
ANeuralNetworksEvent* event = nullptr;
CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event));
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index e98000929a..94dea4f9b2 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -59,6 +59,14 @@ class NNAPIDelegate {
ANeuralNetworksModel* nn_model_ = nullptr;
// The NN API compilation handle
ANeuralNetworksCompilation* nn_compiled_model_ = nullptr;
+
+ // List of state tensors for LSTM, RNN, SVDF.
+ // NN API does not allow ops to maintain states across multiple
+ // invocations. We need to manually create state input tensors from
+ // corresponding state output tensors of TFLite operations, and map them
+ // correctly.
+ std::vector<int> model_states_inputs_;
+ std::vector<int> model_states_outputs_;
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD
index 15999e5d41..c86be65ca7 100644
--- a/tensorflow/contrib/lite/profiling/BUILD
+++ b/tensorflow/contrib/lite/profiling/BUILD
@@ -31,6 +31,33 @@ cc_library(
copts = common_copts,
)
+cc_library(
+ name = "profile_summarizer",
+ srcs = ["profile_summarizer.cc"],
+ hdrs = ["profile_summarizer.h"],
+ deps = [
+ ":profiler",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/core:stats_calculator_portable",
+ ],
+)
+
+cc_test(
+ name = "profile_summarizer_test",
+ srcs = ["profile_summarizer_test.cc"],
+ deps = [
+ ":profile_summarizer",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
cc_test(
name = "profile_buffer_test",
srcs = ["profile_buffer_test.cc"],
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
new file mode 100644
index 0000000000..788f6922d2
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
+
+#include <sstream>
+
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace profiling {
+namespace {
+
+using Detail = tensorflow::StatsCalculator::Detail;
+
+struct OperatorDetails {
+ string name;
+ std::vector<string> inputs;
+ std::vector<string> outputs;
+};
+
+string GetTensorName(const tflite::Interpreter& interpreter, int tensor_index) {
+ const auto tensor = interpreter.tensor(tensor_index);
+ if (tensor == nullptr || tensor->name == nullptr) {
+ return "Unknown";
+ }
+ return tensor->name;
+}
+std::vector<string> GetTensorNames(const tflite::Interpreter& interpreter,
+ const TfLiteIntArray* tensor_indices) {
+ std::vector<string> tensors;
+ tensors.reserve(tensor_indices->size);
+ for (int i = 0; i < tensor_indices->size; i++) {
+ tensors.push_back(GetTensorName(interpreter, tensor_indices->data[i]));
+ }
+ return tensors;
+}
+
+string ToString(const std::vector<string>& str_vector) {
+ std::stringstream stream;
+ stream << "[";
+ bool first = true;
+ for (const auto& s : str_vector) {
+ if (!first) {
+ stream << ", ";
+ } else {
+ first = false;
+ }
+ stream << s;
+ }
+ stream << "]";
+ return stream.str();
+}
+
+OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
+ int node_index) {
+ auto node_reg = interpreter.node_and_registration(node_index);
+ auto inputs = node_reg->first.inputs;
+ auto outputs = node_reg->first.outputs;
+ int code = node_reg->second.builtin_code;
+ const char* op_name = nullptr;
+ if (code == tflite::BuiltinOperator_CUSTOM) {
+ const char* custom_name = node_reg->second.custom_name;
+ op_name = custom_name ? custom_name : "UnknownCustomOp";
+ } else {
+ op_name = tflite::EnumNamesBuiltinOperator()[code];
+ }
+ OperatorDetails details;
+ details.name = op_name;
+ details.inputs = GetTensorNames(interpreter, inputs);
+ details.outputs = GetTensorNames(interpreter, outputs);
+ return details;
+}
+
+} // namespace
+
+ProfileSummarizer::ProfileSummarizer()
+ : stats_calculator_(new ::tensorflow::StatsCalculator(
+ tensorflow::StatSummarizerOptions())) {}
+
+void ProfileSummarizer::ProcessProfiles(
+ const std::vector<const ProfileEvent*>& profile_stats,
+ const tflite::Interpreter& interpreter) {
+ std::vector<const ProfileEvent*> events;
+ std::copy_if(profile_stats.begin(), profile_stats.end(),
+ std::back_inserter(events), [](const ProfileEvent* e) {
+ return e->event_type ==
+ ProfileEvent::EventType::OPERATOR_INVOKE_EVENT &&
+ e->end_timestamp_us >= e->begin_timestamp_us;
+ });
+ // Sort with begin_time.
+ std::sort(events.begin(), events.end(),
+ [](const ProfileEvent* const& a, const ProfileEvent* const& b) {
+ return a->begin_timestamp_us < b->begin_timestamp_us;
+ });
+ if (events.empty()) {
+ return;
+ }
+
+ int64_t base_start_us = events[0]->begin_timestamp_us;
+ int node_num = 0;
+ int64_t curr_total_us = 0;
+ std::map<std::string, Detail> details;
+ for (auto event : events) {
+ auto op_details = GetOperatorDetails(interpreter, event->event_metadata);
+ auto node_name = ToString(op_details.outputs);
+ auto result = details.emplace(node_name, Detail());
+ Detail* detail = &(result.first->second);
+ detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us);
+ int64_t node_exec_time =
+ event->end_timestamp_us - event->begin_timestamp_us;
+ detail->rel_end_us.UpdateStat(node_exec_time);
+ curr_total_us += node_exec_time;
+ ++node_num;
+
+ if (result.second) {
+ detail->name = node_name;
+ detail->type = op_details.name;
+ detail->run_order = node_num;
+ detail->times_called = 0;
+ }
+ ++detail->times_called;
+ }
+ stats_calculator_->UpdateDetails(details);
+ stats_calculator_->UpdateRunTotalUs(curr_total_us);
+}
+} // namespace profiling
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.h b/tensorflow/contrib/lite/profiling/profile_summarizer.h
new file mode 100644
index 0000000000..6fe6ca04f5
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
+#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/core/util/stats_calculator.h"
+
+namespace tflite {
+namespace profiling {
+
+// Creates a summary of operator invocations in the interpreter.
+class ProfileSummarizer {
+ public:
+ ProfileSummarizer();
+ virtual ~ProfileSummarizer() {}
+
+ // Process profile events to update statistics for operator invocations.
+ void ProcessProfiles(const std::vector<const ProfileEvent*>& profile_stats,
+ const tflite::Interpreter& interpreter);
+
+ // Returns a string detailing the accumulated runtime stats in a tab-separated
+ // format which can be pasted into a spreadsheet for further analysis.
+ std::string GetOutputString() const {
+ return stats_calculator_->GetOutputString();
+ }
+
+ std::string GetShortSummary() const {
+ return stats_calculator_->GetShortSummary();
+ }
+
+ // Prints the string returned by GetOutputString().
+ void PrintStepStats() const { stats_calculator_->PrintStepStats(); }
+
+ private:
+ std::unique_ptr<tensorflow::StatsCalculator> stats_calculator_;
+};
+
+} // namespace profiling
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
new file mode 100644
index 0000000000..35cf780713
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+namespace profiling {
+
+namespace {
+
+TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
+ const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1);
+
+ TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
+
+ int32_t* output_data = output->data.i32;
+ *output_data = *(input1->data.i32) + *(input2->data.i32);
+ return kTfLiteOk;
+}
+
+TfLiteRegistration* RegisterSimpleOp() {
+ static TfLiteRegistration registration = {nullptr,
+ nullptr,
+ nullptr,
+ SimpleOpEval,
+ tflite::BuiltinOperator_CUSTOM,
+ "SimpleOpEval",
+ 1};
+ return &registration;
+}
+
+class SimpleOpModel : public SingleOpModel {
+ public:
+ void Init();
+ tflite::Interpreter* GetInterpreter() { return interpreter_.get(); }
+ void SetInputs(int32_t x, int32_t y) {
+ PopulateTensor(inputs_[0], {x});
+ PopulateTensor(inputs_[1], {y});
+ }
+ int32_t GetOutput() { return ExtractVector<int32_t>(output_)[0]; }
+
+ private:
+ int inputs_[2];
+ int output_;
+};
+
+void SimpleOpModel::Init() {
+ inputs_[0] = AddInput({TensorType_INT32, {1}});
+ inputs_[1] = AddInput({TensorType_INT32, {1}});
+ output_ = AddOutput({TensorType_INT32, {}});
+ SetCustomOp("SimpleAdd", {}, RegisterSimpleOp);
+ BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
+}
+
+TEST(ProfileSummarizerTest, Empty) {
+ ProfileSummarizer summarizer;
+ std::string output = summarizer.GetOutputString();
+ EXPECT_GT(output.size(), 0);
+}
+
+#ifdef TFLITE_PROFILING_ENABLED
+TEST(ProfileSummarizerTest, Interpreter) {
+ Profiler profiler;
+ SimpleOpModel m;
+ m.Init();
+ auto interpreter = m.GetInterpreter();
+ interpreter->SetProfiler(&profiler);
+ profiler.StartProfiling();
+ m.SetInputs(1, 2);
+ m.Invoke();
+ // 3 = 1 + 2
+ EXPECT_EQ(m.GetOutput(), 3);
+ profiler.StopProfiling();
+ ProfileSummarizer summarizer;
+ auto events = profiler.GetProfileEvents();
+ EXPECT_EQ(1, events.size());
+ summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter);
+ auto output = summarizer.GetOutputString();
+ // TODO(shashishekhar): Add a better test here.
+ ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output;
+}
+#endif
+
+} // namespace
+} // namespace profiling
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 4920e83970..a40e512045 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -45,7 +45,22 @@ py_library(
":convert",
":convert_saved_model",
":interpreter",
+ ":lite_constants",
":op_hint",
+ "//tensorflow/contrib/saved_model:saved_model_py",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python/tools:freeze_graph_lib",
+ ],
+)
+
+py_test(
+ name = "lite_test",
+ srcs = ["lite_test.py"],
+ data = [":interpreter_test_data"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":lite",
],
)
@@ -110,10 +125,9 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":convert",
- ":lite_constants",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
+ "//tensorflow/python:platform",
"//tensorflow/python/tools:freeze_graph_lib",
],
)
@@ -151,15 +165,6 @@ py_test(
],
)
-py_binary(
- name = "convert_saved_model_to_frozen_graph",
- srcs = ["convert_saved_model_to_frozen_graph.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convert_saved_model",
- ],
-)
-
# Transitive dependencies of this target will be included in the pip package.
py_library(
name = "tf_lite_py_pip",
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index a7eddf3408..54fec9d61f 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.lite.python import convert
-from tensorflow.contrib.lite.python import lite_constants
-from tensorflow.contrib.lite.toco import model_flags_pb2
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
@@ -110,12 +107,12 @@ def _get_signature_def(meta_graph, signature_key):
signature_def_map = meta_graph.signature_def
signature_def_keys = set(signature_def_map.keys())
logging.info(
- "The given saved_model MetaGraphDef contains SignatureDefs with the "
+ "The given SavedModel MetaGraphDef contains SignatureDefs with the "
"following keys: %s", signature_def_keys)
if signature_key not in signature_def_keys:
- raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible "
- "values are '{}'. ".format(signature_key,
- signature_def_keys))
+ raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible "
+ "values are '{}'.".format(signature_key,
+ ",".join(signature_def_keys)))
signature_def = signature_def_utils.get_signature_def_by_key(
meta_graph, signature_key)
return signature_def
@@ -207,8 +204,8 @@ def _get_tensors(graph, signature_def_tensor_names=None,
return tensors
-def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
- output_arrays, tag_set, signature_key, batch_size):
+def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
+ output_arrays, tag_set, signature_key):
"""Converts a SavedModel to a frozen graph.
Args:
@@ -224,8 +221,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
analyze. All tags in the tag set must be present. (default "serve")
signature_key: Key identifying SignatureDef containing inputs and outputs.
- batch_size: Batch size for the model. Replaces the first dimension of an
- input size array if undefined. (default 1)
Returns:
frozen_graph_def: Frozen GraphDef.
@@ -237,7 +232,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
SavedModel doesn't contain a MetaGraphDef identified by tag_set.
signature_key is not in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
- input_shapes has a None value after the 1st dimension.
input_arrays or output_arrays are not valid.
Unable to load Session.
"""
@@ -246,8 +240,6 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
if tag_set is None:
tag_set = set([tag_constants.SERVING])
- if batch_size is None:
- batch_size = 1
# Read SignatureDef.
meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
@@ -264,23 +256,13 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
in_tensors = _get_tensors(graph, inputs, input_arrays)
out_tensors = _get_tensors(graph, outputs, output_arrays)
- # Gets fully defined tensor shape. An input tensor with None in the first
- # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size.
- # Shapes with None after the first dimension result in a ValueError.
- # TODO(zhixianyan): Add supports for input tensor with more None in shape.
+ # Gets fully defined tensor shape.
for tensor in in_tensors:
if (input_shapes and tensor.name in input_shapes and
input_shapes[tensor.name] is not None):
shape = input_shapes[tensor.name]
else:
shape = tensor.get_shape().as_list()
-
- if None in shape[1:]:
- raise ValueError(
- "None is only supported in the 1st dimension. Tensor '{0}' has "
- "invalid shape '{1}'.".format(tensor.name, shape))
- elif shape[0] is None:
- shape[0] = batch_size
tensor.set_shape(shape)
output_names = [node.split(":")[0] for node in outputs]
@@ -289,133 +271,3 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
return frozen_graph_def, in_tensors, out_tensors
raise ValueError("Unable to load Session.")
-
-
-def saved_model_to_frozen_graphdef(
- saved_model_dir,
- output_file_model,
- output_file_flags,
- input_arrays=None,
- input_shapes=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- batch_size=1):
- """Converts a SavedModel to a frozen graph. Writes graph to tmp directory.
-
- Stores frozen graph and command line flags in the tmp directory.
-
- Args:
- saved_model_dir: SavedModel directory to convert.
- output_file_model: Full file path to save frozen graph.
- output_file_flags: Full file path to save ModelFlags.
- input_arrays: List of input tensors to freeze graph with. Uses input arrays
- from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
- integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
- Automatically determined when input shapes is None (e.g., {"foo" : None}).
- (default None)
- output_arrays: List of output tensors to freeze graph with. Uses output
- arrays from SignatureDef when none are provided. (default None)
- tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
- signature_key: Key identifying SignatureDef containing inputs and outputs.
- batch_size: Batch size for the model. Replaces the first dimension of an
- input size array if undefined. (default 1)
-
- Returns: None.
-
- Raises:
- ValueError: Unable to convert to frozen graph.
- """
- frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
- saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
- signature_key, batch_size)
-
- # Initialize model flags.
- model = model_flags_pb2.ModelFlags()
-
- for input_tensor in in_tensors:
- input_array = model.input_arrays.add()
- input_array.name = convert.tensor_name(input_tensor)
- input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
-
- for output_tensor in out_tensors:
- model.output_arrays.append(convert.tensor_name(output_tensor))
-
- # Write model and ModelFlags to file. ModelFlags contain input array and
- # output array information that is parsed from the SignatureDef and used for
- # analysis by TOCO.
- _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString())
- _write_and_flush_file(output_file_flags, model.SerializeToString())
-
-
-def tflite_from_saved_model(
- saved_model_dir,
- output_file=None,
- input_arrays=None,
- input_shapes=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- batch_size=1,
- inference_type=lite_constants.FLOAT,
- input_format=lite_constants.TENSORFLOW_GRAPHDEF,
- output_format=lite_constants.TFLITE,
- quantized_input_stats=None,
- drop_control_dependency=True):
- """Converts a SavedModel to TFLite FlatBuffer.
-
- Args:
- saved_model_dir: SavedModel directory to convert.
- output_file: File path to write result TFLite FlatBuffer.
- input_arrays: List of input tensors to freeze graph with. Uses input arrays
- from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
- integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
- Automatically determined when input shapes is None (e.g., {"foo" : None}).
- (default None)
- output_arrays: List of output tensors to freeze graph with. Uses output
- arrays from SignatureDef when none are provided. (default None)
- tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
- signature_key: Key identifying SignatureDef containing inputs and outputs.
- batch_size: Batch size for the model. Replaces the first dimension of an
- input size array if undefined. (default 1)
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT)
- quantized_input_stats: For each member of input_tensors the mean and
- std deviation of training data. Only needed if `inference_type` is
- `QUANTIZED_UINT8`.
- drop_control_dependency: Drops control dependencies silently. This is due
- to tf lite not supporting control dependencies.
-
- Returns:
- The converted data. For example if tflite was the destination, then
- this will be a tflite flatbuffer in a bytes array.
-
- Raises:
- ValueError: Unable to convert to frozen graph.
- """
- frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
- saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
- signature_key, batch_size)
-
- result = convert.toco_convert(
- input_data=frozen_graph_def,
- input_tensors=in_tensors,
- output_tensors=out_tensors,
- inference_type=inference_type,
- input_format=input_format,
- output_format=output_format,
- quantized_input_stats=quantized_input_stats,
- drop_control_dependency=drop_control_dependency)
-
- if output_file is not None:
- with gfile.Open(output_file, "wb") as f:
- f.write(result)
- logging.info("Successfully converted to: %s", output_file)
-
- return result
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py
index db95fc8ad7..f69381d0e6 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model_test.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py
@@ -25,12 +25,12 @@ from __future__ import print_function
import os
from tensorflow.contrib.lite.python import convert_saved_model
-from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator_lib as estimator
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.layers import layers
from tensorflow.python.ops import array_ops
@@ -38,13 +38,13 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.losses import losses
-from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training as train
-class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
+class FreezeSavedModelTest(test_util.TensorFlowTestCase):
def _createSimpleSavedModel(self, shape):
"""Create a simple SavedModel on the fly."""
@@ -57,82 +57,163 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
return saved_model_dir
+ def _createSavedModelTwoInputArrays(self, shape):
+ """Create a simple SavedModel."""
+ saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
+ with session.Session() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=shape, dtype=dtypes.float32, name="inputB")
+ in_tensor_2 = array_ops.placeholder(
+ shape=shape, dtype=dtypes.float32, name="inputA")
+ out_tensor = in_tensor_1 + in_tensor_2
+ inputs = {"x": in_tensor_1, "y": in_tensor_2}
+ outputs = {"z": out_tensor}
+ saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+ return saved_model_dir
+
+ def _getArrayNames(self, tensors):
+ return [tensor.name for tensor in tensors]
+
+ def _getArrayShapes(self, tensors):
+ dims = []
+ for tensor in tensors:
+ dim_tensor = []
+ for dim in tensor.shape:
+ if isinstance(dim, tensor_shape.Dimension):
+ dim_tensor.append(dim.value)
+ else:
+ dim_tensor.append(dim)
+ dims.append(dim_tensor)
+ return dims
+
+ def _convertSavedModel(self,
+ saved_model_dir,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=None):
+ graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model(
+ saved_model_dir=saved_model_dir,
+ input_arrays=input_arrays,
+ input_shapes=input_shapes,
+ output_arrays=output_arrays,
+ tag_set=tag_set,
+ signature_key=signature_key)
+ return graph_def, in_tensors, out_tensors
+
def testSimpleSavedModel(self):
- """Test a simple SavedModel created on the fly."""
- # Create a simple SavedModel
+ """Test a SavedModel."""
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
- # Convert to tflite
- result = convert_saved_model.tflite_from_saved_model(
- saved_model_dir=saved_model_dir)
- self.assertTrue(result)
+ _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir)
+
+ self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
+ self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
+ self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])
def testSimpleSavedModelWithNoneBatchSizeInShape(self):
- """Test a simple SavedModel, with None in input tensor's shape."""
+ """Test a SavedModel with None in input tensor's shape."""
saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
- result = convert_saved_model.tflite_from_saved_model(
- saved_model_dir=saved_model_dir)
- self.assertTrue(result)
+ _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir)
- def testSimpleSavedModelWithMoreNoneInShape(self):
- """Test a simple SavedModel, fail as more None in input shape."""
- saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3])
- # Convert to tflite: this should raise ValueError, as 3rd dim is None.
- with self.assertRaises(ValueError):
- convert_saved_model.tflite_from_saved_model(
- saved_model_dir=saved_model_dir)
+ self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
+ self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
+ self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]])
- def testSimpleSavedModelWithWrongSignatureKey(self):
- """Test a simple SavedModel, fail as given signature is invalid."""
+ def testSimpleSavedModelWithInvalidSignatureKey(self):
+ """Test a SavedModel that fails due to an invalid signature_key."""
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
- # Convert to tflite: this should raise ValueError, as
- # signature_key does not exit in the saved_model.
- with self.assertRaises(ValueError):
- convert_saved_model.tflite_from_saved_model(
- saved_model_dir=saved_model_dir, signature_key="wrong-key")
-
- def testSimpleSavedModelWithWrongOutputArray(self):
- """Test a simple SavedModel, fail as given output_arrays is invalid."""
- # Create a simple SavedModel
+ with self.assertRaises(ValueError) as error:
+ self._convertSavedModel(saved_model_dir, signature_key="invalid-key")
+ self.assertEqual(
+ "No 'invalid-key' in the SavedModel's SignatureDefs. "
+ "Possible values are 'serving_default'.", str(error.exception))
+
+ def testSimpleSavedModelWithInvalidOutputArray(self):
+ """Test a SavedModel that fails due to invalid output arrays."""
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
- # Convert to tflite: this should raise ValueError, as
- # output_arrays is not valid for the saved_model.
- with self.assertRaises(ValueError):
- convert_saved_model.tflite_from_saved_model(
- saved_model_dir=saved_model_dir, output_arrays=["wrong-output"])
+ with self.assertRaises(ValueError) as error:
+ self._convertSavedModel(saved_model_dir, output_arrays=["invalid-output"])
+ self.assertEqual("Invalid tensors 'invalid-output' were found.",
+ str(error.exception))
def testSimpleSavedModelWithWrongInputArrays(self):
- """Test a simple SavedModel, fail as given input_arrays is invalid."""
+ """Test a SavedModel that fails due to invalid input arrays."""
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
- # Checks invalid input_arrays.
- with self.assertRaises(ValueError):
- convert_saved_model.tflite_from_saved_model(
- saved_model_dir=saved_model_dir, input_arrays=["wrong-input"])
- # Checks valid and invalid input_arrays.
- with self.assertRaises(ValueError):
- convert_saved_model.tflite_from_saved_model(
- saved_model_dir=saved_model_dir,
- input_arrays=["Placeholder", "wrong-input"])
+
+ # Check invalid input_arrays.
+ with self.assertRaises(ValueError) as error:
+ self._convertSavedModel(saved_model_dir, input_arrays=["invalid-input"])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
+
+ # Check valid and invalid input_arrays.
+ with self.assertRaises(ValueError) as error:
+ self._convertSavedModel(
+ saved_model_dir, input_arrays=["Placeholder", "invalid-input"])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
def testSimpleSavedModelWithCorrectArrays(self):
- """Test a simple SavedModel, with correct input_arrays and output_arrays."""
+ """Test a SavedModel with correct input_arrays and output_arrays."""
saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
- result = convert_saved_model.tflite_from_saved_model(
+ _, in_tensors, out_tensors = self._convertSavedModel(
saved_model_dir=saved_model_dir,
input_arrays=["Placeholder"],
output_arrays=["add"])
- self.assertTrue(result)
+
+ self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
+ self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
+ self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]])
def testSimpleSavedModelWithCorrectInputArrays(self):
- """Test a simple SavedModel, with correct input_arrays and input_shapes."""
+ """Test a SavedModel with correct input_arrays and input_shapes."""
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
- result = convert_saved_model.tflite_from_saved_model(
+ _, in_tensors, out_tensors = self._convertSavedModel(
saved_model_dir=saved_model_dir,
input_arrays=["Placeholder"],
input_shapes={"Placeholder": [1, 16, 16, 3]})
- self.assertTrue(result)
+
+ self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
+ self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
+ self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])
+
+ def testTwoInputArrays(self):
+ """Test a simple SavedModel."""
+ saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3])
+
+ _, in_tensors, out_tensors = self._convertSavedModel(
+ saved_model_dir=saved_model_dir, input_arrays=["inputB", "inputA"])
+
+ self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
+ self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0", "inputB:0"])
+ self.assertEqual(
+ self._getArrayShapes(in_tensors), [[1, 16, 16, 3], [1, 16, 16, 3]])
+
+ def testSubsetInputArrays(self):
+ """Test a SavedModel with a subset of the input array names of the model."""
+ saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3])
+
+ # Check case where input shape is given.
+ _, in_tensors, out_tensors = self._convertSavedModel(
+ saved_model_dir=saved_model_dir,
+ input_arrays=["inputA"],
+ input_shapes={"inputA": [1, 16, 16, 3]})
+
+ self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
+ self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"])
+ self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])
+
+ # Check case where input shape is None.
+ _, in_tensors, out_tensors = self._convertSavedModel(
+ saved_model_dir=saved_model_dir, input_arrays=["inputA"])
+
+ self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
+ self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"])
+ self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])
def testMultipleMetaGraphDef(self):
- """Test saved model with multiple MetaGraphDef."""
+ """Test saved model with multiple MetaGraphDefs."""
saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd")
builder = saved_model.builder.SavedModelBuilder(saved_model_dir)
with session.Session(graph=ops.Graph()) as sess:
@@ -161,91 +242,13 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
builder.save(True)
# Convert to tflite
- convert_saved_model.tflite_from_saved_model(
+ _, in_tensors, out_tensors = self._convertSavedModel(
saved_model_dir=saved_model_dir,
tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"]))
-
-class ConvertSavedModelTestBasicGraphToText(test_util.TensorFlowTestCase):
-
- def _createSimpleSavedModel(self, shape):
- """Create a simple SavedModel."""
- saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
- with session.Session() as sess:
- in_tensor_1 = array_ops.placeholder(
- shape=shape, dtype=dtypes.float32, name="inputB")
- in_tensor_2 = array_ops.placeholder(
- shape=shape, dtype=dtypes.float32, name="inputA")
- out_tensor = in_tensor_1 + in_tensor_2
- inputs = {"x": in_tensor_1, "y": in_tensor_2}
- outputs = {"z": out_tensor}
- saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
- return saved_model_dir
-
- def _getInputArrayNames(self, model_proto):
- return [data.name for data in model_proto.input_arrays]
-
- def _getInputArrayShapes(self, model_proto):
- return [
- [dim for dim in data.shape.dims] for data in model_proto.input_arrays
- ]
-
- def _get_model_flags_proto_from_file(self, filename):
- proto = _model_flags_pb2.ModelFlags()
- with gfile.Open(filename, "rb") as output_file:
- proto.ParseFromString(output_file.read())
- output_file.close()
- return proto
-
- def testSimpleSavedModel(self):
- """Test a simple SavedModel."""
- saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
- output_file_model = os.path.join(self.get_temp_dir(), "model.pb")
- output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt")
-
- convert_saved_model.saved_model_to_frozen_graphdef(
- saved_model_dir=saved_model_dir,
- output_file_model=output_file_model,
- output_file_flags=output_file_flags,
- input_arrays=["inputB", "inputA"])
-
- proto = self._get_model_flags_proto_from_file(output_file_flags)
- self.assertEqual(proto.output_arrays, ["add"])
- self.assertEqual(self._getInputArrayNames(proto), ["inputA", "inputB"])
- self.assertEqual(
- self._getInputArrayShapes(proto), [[1, 16, 16, 3], [1, 16, 16, 3]])
-
- def testSimpleSavedModelWithDifferentInputNames(self):
- """Test a simple SavedModel."""
- saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
- output_file_model = os.path.join(self.get_temp_dir(), "model.pb")
- output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt")
-
- # Check case where input shape is given.
- convert_saved_model.saved_model_to_frozen_graphdef(
- saved_model_dir=saved_model_dir,
- output_file_model=output_file_model,
- output_file_flags=output_file_flags,
- input_arrays=["inputA"],
- input_shapes={"inputA": [1, 16, 16, 3]})
-
- proto = self._get_model_flags_proto_from_file(output_file_flags)
- self.assertEqual(proto.output_arrays, ["add"])
- self.assertEqual(self._getInputArrayNames(proto), ["inputA"])
- self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]])
-
- # Check case where input shape is None.
- convert_saved_model.saved_model_to_frozen_graphdef(
- saved_model_dir=saved_model_dir,
- output_file_model=output_file_model,
- output_file_flags=output_file_flags,
- input_arrays=["inputA"],
- input_shapes={"inputA": None})
-
- proto = self._get_model_flags_proto_from_file(output_file_flags)
- self.assertEqual(proto.output_arrays, ["add"])
- self.assertEqual(self._getInputArrayNames(proto), ["inputA"])
- self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]])
+ self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
+ self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
+ self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]])
class Model(keras.Model):
@@ -354,7 +357,7 @@ def dummy_input_fn():
return image, labels
-class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
+class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
def testTrainedMnistSavedModel(self):
"""Test mnist SavedModel, trained with dummy data and small steps."""
@@ -379,13 +382,16 @@ class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
# Convert to tflite and test output
saved_model_name = os.listdir(saved_model_dir)[0]
saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name)
- output_file = os.path.join(saved_model_dir, saved_model_final_dir + ".lite")
+
# TODO(zhixianyan): no need to limit output_arrays to `Softmax'
# once b/74205001 fixed and argmax implemented in tflite.
- result = convert_saved_model.tflite_from_saved_model(
+ result = convert_saved_model.freeze_saved_model(
saved_model_dir=saved_model_final_dir,
+ input_arrays=None,
+ input_shapes=None,
output_arrays=["Softmax"],
- output_file=output_file)
+ tag_set=None,
+ signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
self.assertTrue(result)
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py
deleted file mode 100644
index 4d9782f4a6..0000000000
--- a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Python console command for generating frozen models from SavedModels.
-
-This exists to add SavedModel compatibility to TOCO.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import sys
-from tensorflow.contrib.lite.python.convert_saved_model import saved_model_to_frozen_graphdef
-from tensorflow.python.platform import app
-
-FLAGS = None
-
-
-def execute(unused_args):
- """Calls function to convert the SavedModel to a frozen graph."""
- # Error handling.
- if FLAGS.input_shapes and not FLAGS.input_arrays:
- raise ValueError("Input shapes requires input arrays to be specified.")
-
- # Calls saved_model_to_frozen_graphdef function to generate frozen graph.
- input_arrays = (FLAGS.input_arrays.split(",") if FLAGS.input_arrays else None)
- input_shapes = None
- if FLAGS.input_shapes:
- input_shapes = {
- input_arrays[idx]: shape.split(",")
- for idx, shape in enumerate(FLAGS.input_shapes.split(":"))
- }
- output_arrays = (
- FLAGS.output_arrays.split(",") if FLAGS.output_arrays else None)
- tag_set = set(FLAGS.tag_set.split(",")) if FLAGS.tag_set else None
-
- saved_model_to_frozen_graphdef(
- saved_model_dir=FLAGS.saved_model_directory,
- output_file_model=FLAGS.output_file_model,
- output_file_flags=FLAGS.output_file_flags,
- input_arrays=input_arrays,
- input_shapes=input_shapes,
- output_arrays=output_arrays,
- tag_set=tag_set,
- signature_key=FLAGS.signature_key,
- batch_size=FLAGS.batch_size)
-
-
-def main():
- global FLAGS
- # Parses flags.
- parser = argparse.ArgumentParser(
- description="Invoke SavedModel to frozen model converter.")
- parser.add_argument(
- "saved_model_directory",
- type=str,
- help="Full path to directory containing the SavedModel.")
- parser.add_argument(
- "output_file_model",
- type=str,
- help="Full file path to save frozen graph.")
- parser.add_argument(
- "output_file_flags", type=str, help="Full file path to save ModelFlags.")
- parser.add_argument(
- "--input_arrays",
- type=str,
- help="Name of the input arrays, comma-separated.")
- parser.add_argument(
- "--input_shapes",
- type=str,
- help="Shapes corresponding to --input_arrays, colon-separated.")
- parser.add_argument(
- "--output_arrays",
- type=str,
- help="Name of the output arrays, comma-separated.")
- parser.add_argument(
- "--tag_set", type=str, help="Name of output arrays, comma-separated.")
- parser.add_argument(
- "--signature_key",
- type=str,
- help="Key identifying SignatureDef containing inputs and outputs.")
- parser.add_argument(
- "--batch_size",
- type=int,
- help="Batch size for the model. Replaces the first dimension of an "
- "input size array if undefined.")
-
- FLAGS, unparsed = parser.parse_known_args()
-
- app.run(main=execute, argv=[sys.argv[0]] + unparsed)
-
-
-if __name__ == "__main__":
- main()
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
index 453eda6e73..12ab38847d 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
@@ -15,7 +15,7 @@ cc_library(
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/core:lib",
"//tensorflow/python:numpy_lib",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
"@com_google_absl//absl/memory",
],
)
@@ -27,6 +27,6 @@ tf_py_wrap_cc(
],
deps = [
":interpreter_wrapper_lib",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 16f4f30b94..6b12c91924 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -42,6 +42,8 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
return nullptr;
}
+ tensorflow::ImportNumpy();
+
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (interpreter) {
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 86b25e68ac..f7f2d40a02 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -16,23 +16,199 @@
EXPERIMENTAL: APIs here are unstable and likely to change without notice.
+@@TocoConverter
@@toco_convert
@@toco_convert_protos
-@@tflite_from_saved_model
@@Interpreter
@@OpHint
@@convert_op_hints_to_stubs
+@@FLOAT
+@@QUANTIZED_UINT8
+@@TFLITE
+@@GRAPHVIZ_DOT
+
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import
+from tensorflow.contrib.lite.python import lite_constants as constants
+from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert
-from tensorflow.contrib.lite.python.convert import toco_convert_protos
-from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model
-from tensorflow.contrib.lite.python.interpreter import Interpreter
-from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs
-from tensorflow.contrib.lite.python.op_hint import OpHint
-# pylint: enable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model
+from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import
+from tensorflow.python.framework import graph_util as tf_graph_util
+from tensorflow.python.ops.variables import global_variables_initializer
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
+
+
+class TocoConverter(object):
+ """Convert a TensorFlow model into `output_format` using TOCO.
+
+ This is used to convert from a TensorFlow GraphDef or SavedModel into either a
+ TFLite FlatBuffer or graph visualization.
+
+ Attributes:
+
+ inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
+ (default FLOAT)
+ output_format: Type of data to write (currently must be TFLITE or
+ GRAPHVIZ_DOT). (default TFLITE)
+ quantized_input_stats: The mean and std deviation of training data for each
+ input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`.
+ (default None)
+ drop_control_dependency: Boolean indicating whether to drop control
+ dependencies silently. This is due to TFLite not supporting control
+ dependencies. (default True)
+ allow_custom_ops: Boolean indicating whether to allow custom operations.
+ (default False)
+
+ Example usage:
+
+ # Converting a frozen graph.
+ converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
+ tflite_model = converter.convert()
+ open("converted_model.tflite", "wb").write(tflite_model)
+
+ # Converting a SavedModel.
+ converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ tflite_model = converter.convert()
+ """
+
+ def __init__(self, graph_def, input_tensors, output_tensors):
+ """Constructor for TocoConverter.
+
+ Args:
+
+ graph_def: TensorFlow GraphDef.
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ """
+ self._graph_def = graph_def
+ self._input_tensors = input_tensors
+ self._output_tensors = output_tensors
+ self.inference_type = constants.FLOAT
+ self.output_format = constants.TFLITE
+ self.quantized_input_stats = None
+ self.drop_control_dependency = True
+ self.allow_custom_ops = False
+
+ @classmethod
+ def from_session(cls,
+ sess,
+ input_tensors,
+ output_tensors,
+ freeze_variables=False):
+ """Creates a TocoConverter class from a TensorFlow Session.
+
+ Args:
+ sess: TensorFlow Session.
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ freeze_variables: Boolean indicating whether the variables need to be
+ converted into constants via the freeze_graph.py script.
+ (default False)
+
+ Returns:
+ TocoConverter class.
+ """
+
+ # Get GraphDef.
+ if freeze_variables:
+ sess.run(global_variables_initializer())
+ output_arrays = [tensor_name(tensor) for tensor in output_tensors]
+ graph_def = tf_graph_util.convert_variables_to_constants(
+ sess, sess.graph_def, output_arrays)
+ else:
+ graph_def = sess.graph_def
+
+ # Create TocoConverter class.
+ return cls(graph_def, input_tensors, output_tensors)
+
+ @classmethod
+ def from_saved_model(
+ cls,
+ saved_model_dir,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
+ """Creates a TocoConverter class from a SavedModel.
+
+ Args:
+ saved_model_dir: SavedModel directory to convert.
+ input_arrays: List of input tensors to freeze graph with. Uses input
+ arrays from SignatureDef when none are provided. (default None)
+ input_shapes: Map of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" :
+ None}). (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present. (default "serve")
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+
+ Returns:
+ TocoConverter class.
+ """
+ if tag_set is None:
+ tag_set = set([tag_constants.SERVING])
+
+ result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
+ output_arrays, tag_set, signature_key)
+ return cls(
+ graph_def=result[0], input_tensors=result[1], output_tensors=result[2])
+
+ def convert(self):
+ """Converts a TensorFlow GraphDef based on instance variables.
+
+ Returns:
+ The converted data in serialized format. Either a TFLite Flatbuffer or a
+ Graphviz graph depending on value in `output_format`.
+
+ Raises:
+ ValueError:
+ None value for dimension in input_tensor.
+ """
+ # Checks dimensions in input tensor.
+ for tensor in self._input_tensors:
+ shape = tensor.get_shape().as_list()
+ if None in shape[1:]:
+ raise ValueError(
+ "None is only supported in the 1st dimension. Tensor '{0}' has "
+ "invalid shape '{1}'.".format(tensor.name, shape))
+ elif shape[0] is None:
+ self._set_batch_size(batch_size=1)
+
+ # Converts model.
+ result = toco_convert(
+ input_data=self._graph_def,
+ input_tensors=self._input_tensors,
+ output_tensors=self._output_tensors,
+ inference_type=self.inference_type,
+ input_format=constants.TENSORFLOW_GRAPHDEF,
+ output_format=self.output_format,
+ quantized_input_stats=self.quantized_input_stats,
+ drop_control_dependency=self.drop_control_dependency)
+ return result
+
+ def _set_batch_size(self, batch_size):
+ """Sets the first dimension of the input tensor to `batch_size`.
+
+ Args:
+ batch_size: Batch size for the model. Replaces the first dimension of an
+ input size array if undefined. (default 1)
+ """
+ for tensor in self._input_tensors:
+ shape = tensor.get_shape().as_list()
+ shape[0] = batch_size
+ tensor.set_shape(shape)
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
new file mode 100644
index 0000000000..2f3105f3e6
--- /dev/null
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -0,0 +1,323 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lite.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.python import lite_constants
+from tensorflow.contrib.lite.python.interpreter import Interpreter
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import saved_model
+
+
+class FromSessionTest(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testQuantization(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
+ out_tensor = array_ops.fake_quant_with_min_max_args(
+ in_tensor + in_tensor, min=0., max=1., name='output')
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.inference_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = [(0., 1.)] # mean, std_dev
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((1., 0.),
+ input_details[0]['quantization']) # scale, zero_point
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('output', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+
+ def testBatchSizeInvalid(self):
+ in_tensor = array_ops.placeholder(
+ shape=[None, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Test invalid shape. None after 1st dimension.
+ in_tensor = array_ops.placeholder(
+ shape=[1, None, 16, 3], dtype=dtypes.float32)
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ with self.assertRaises(ValueError) as error:
+ converter.convert()
+ self.assertEqual(
+ 'None is only supported in the 1st dimension. Tensor '
+ '\'Placeholder_1:0\' has invalid shape \'[1, None, 16, 3]\'.',
+ str(error.exception))
+
+ def testBatchSizeValid(self):
+ in_tensor = array_ops.placeholder(
+ shape=[None, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFreezeGraph(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ var = variable_scope.get_variable(
+ 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + var
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor], [out_tensor], freeze_variables=True)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testGraphviz(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.output_format = lite_constants.GRAPHVIZ_DOT
+ graphviz_output = converter.convert()
+ self.assertTrue(graphviz_output)
+
+
+class FromSavedModelTest(test_util.TensorFlowTestCase):
+
+ def _createSavedModel(self, shape):
+ """Create a simple SavedModel."""
+ saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
+ with session.Session() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=shape, dtype=dtypes.float32, name='inputB')
+ in_tensor_2 = array_ops.placeholder(
+ shape=shape, dtype=dtypes.float32, name='inputA')
+ out_tensor = in_tensor_1 + in_tensor_2
+ inputs = {'x': in_tensor_1, 'y': in_tensor_2}
+ outputs = {'z': out_tensor}
+ saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+ return saved_model_dir
+
+ def testSimpleModel(self):
+ """Test a SavedModel."""
+ saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('inputA', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ self.assertEqual('inputB', input_details[1]['name'])
+ self.assertEqual(np.float32, input_details[1]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+ self.assertEqual((0., 0.), input_details[1]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testNoneBatchSize(self):
+ """Test a SavedModel, with None in input tensor's shape."""
+ saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
+
+ converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('inputA', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ self.assertEqual('inputB', input_details[1]['name'])
+ self.assertEqual(np.float32, input_details[1]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+ self.assertEqual((0., 0.), input_details[1]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testOrderInputArrays(self):
+ """Test a SavedModel ordering of input arrays."""
+ saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
+
+ converter = lite.TocoConverter.from_saved_model(
+ saved_model_dir, input_arrays=['inputB', 'inputA'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('inputA', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ self.assertEqual('inputB', input_details[1]['name'])
+ self.assertEqual(np.float32, input_details[1]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+ self.assertEqual((0., 0.), input_details[1]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testSubsetInputArrays(self):
+ """Test a SavedModel with a subset of the input array names of the model."""
+ saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
+
+ # Check case where input shape is given.
+ converter = lite.TocoConverter.from_saved_model(
+ saved_model_dir,
+ input_arrays=['inputA'],
+ input_shapes={'inputA': [1, 16, 16, 3]})
+
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check case where input shape is None.
+ converter = lite.TocoConverter.from_saved_model(
+ saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
+
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc
index ac408d2f94..64ab0a9fe2 100644
--- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc
+++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc
@@ -57,7 +57,6 @@ const char* kFileFooter =
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_
-}
)";
} // anonymous namespace
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 75ac24719a..8cab6cd8cd 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -143,6 +143,7 @@ void TfLiteDriver::AllocateTensors() {
Invalidate("Failed to allocate tensors");
return;
}
+ ResetLSTMStateTensors();
must_allocate_tensors_ = false;
}
}
@@ -281,5 +282,24 @@ bool TfLiteDriver::CheckResults() {
return success;
}
+void TfLiteDriver::ResetLSTMStateTensors() {
+ // This is a workaround for initializing state tensors for LSTM.
+ // TODO(ycling): Refactoring and find a better way to initialize state
+ // tensors. Maybe write the reset instructions into the test data.
+ for (auto node_index : interpreter_->execution_plan()) {
+ const auto& node_and_reg = interpreter_->node_and_registration(node_index);
+ const auto& node = node_and_reg->first;
+ const auto& registration = node_and_reg->second;
+ if (registration.builtin_code == tflite::BuiltinOperator_LSTM &&
+ node.outputs->size >= 2) {
+ // The first 2 outputs of LSTM are state tensors.
+ for (int i = 0; i < 2; ++i) {
+ int node_index = node.outputs->data[i];
+ ResetTensor(node_index);
+ }
+ }
+ }
+}
+
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index 02b7de1534..5493ba3631 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -48,6 +48,8 @@ class TfLiteDriver : public TestRunner {
string ReadOutput(int id) override { return "no-op"; }
private:
+ void ResetLSTMStateTensors();
+
class Expectation;
bool use_nnapi_ = false;
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index 6e5927295f..3aeebb14f1 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -16,8 +16,6 @@ limitations under the License.
#include <cmath>
#include <memory>
-#include <set>
-#include <unordered_set>
#include <vector>
#include "absl/strings/str_replace.h"
@@ -304,7 +302,15 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
constexpr char kRNNBackEdgeFormat[] =
"\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n";
- std::set<string> already_added_arrays;
+ for (const auto& array_kv : model.GetArrayMap()) {
+ // Add node for array.
+ const string& array_name = array_kv.first;
+ const auto& array_properties = GetPropertiesForArray(model, array_name);
+ AppendF(output_file_contents, kNodeFormat, array_name,
+ array_properties.label, "octagon",
+ array_properties.color.FillColorString().c_str(),
+ array_properties.color.TextColorString().c_str());
+ }
for (int op_index = 0; op_index < model.operators.size(); op_index++) {
const Operator& op = *model.operators[op_index];
// Add node for operator.
@@ -313,20 +319,13 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label,
"box", op_properties.color.FillColorString().c_str(),
op_properties.color.TextColorString().c_str());
- // Add nodes and edges for all inputs of the operator.
+ // Add edges for all inputs of the operator.
for (const auto& input : op.inputs) {
if (!model.HasArray(input)) {
// Arrays should _always_ exist. Except, perhaps, during development.
continue;
}
auto array_properties = GetPropertiesForArray(model, input);
- if (!already_added_arrays.count(input)) {
- AppendF(output_file_contents, kNodeFormat, input,
- array_properties.label, "octagon",
- array_properties.color.FillColorString().c_str(),
- array_properties.color.TextColorString().c_str());
- }
-
// Draw lines that transport more data thicker (Otherwise, where would the
// data fit? right?).
float line_width =
@@ -342,22 +341,14 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
}
AppendF(output_file_contents, kEdgeFormat, input, operator_id, line_width,
weight);
- already_added_arrays.insert(input);
}
- // Add nodes and edges for all outputs of the operator.
+ // Add edges for all outputs of the operator.
for (const auto& output : op.outputs) {
if (!model.HasArray(output)) {
// Arrays should _always_ exist. Except, perhaps, during development.
continue;
}
auto array_properties = GetPropertiesForArray(model, output);
- if (!already_added_arrays.count(output)) {
- AppendF(output_file_contents, kNodeFormat, output,
- array_properties.label, "octagon",
- array_properties.color.FillColorString().c_str(),
- array_properties.color.TextColorString().c_str());
- }
-
// See comments above regarding weight and line_width calculations.
float line_width =
std::max(0.5f, array_properties.log2_buffer_size / 3.0f);
@@ -367,7 +358,6 @@ void DumpGraphviz(const Model& model, string* output_file_contents) {
}
AppendF(output_file_contents, kEdgeFormat, operator_id, output,
line_width, weight);
- already_added_arrays.insert(output);
}
}
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index f0fd638a61..29a83bd26f 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -1,69 +1,198 @@
-# TensorFlow Lite Optimizing Converter (TOCO) Python API reference
+# TensorFlow Lite Optimizing Converter & Interpreter Python API reference
-This page provides examples on how to use TOCO via the Python API. It is
-complemented by the following documents:
+This page provides examples on how to use TOCO and the TensorFlow Lite
+interpreter via the Python API. It is complemented by the following documents:
* [README](../README.md)
* [Command-line examples](cmdline_examples.md)
* [Command-line glossary](cmdline_reference.md)
+Table of contents:
+
+* [High-level overview](#high-level-overview)
+* [API](#api)
+* [Basic examples](#basic)
+ * [Exporting a GraphDef with constants](#basic-graphdef-const)
+ * [Exporting a GraphDef with variables](#basic-graphdef-var)
+ * [Exporting a SavedModel](#basic-savedmodel)
+* [Complex examples](#complex)
+ * [Exporting a quantized GraphDef](#complex-quant)
+* [TensorFlow Lite Python interpreter](#interpreter)
+ * [Using the interpreter from a model file](#interpreter-file)
+ * [Using the interpreter from model data](#interpreter-data)
+
## High-level overview
While the TensorFlow Lite Optimizing Converter can be used from the command
-line, it is often convenient to use it as part of Python model build and
+line, it is often convenient to use it as part of a Python model build and
training script. This is so that conversion can be part of your model
development pipeline. This allows you to know early and often that you are
designing a model that can be targeted to devices with mobile.
## API
-In Python you can run `help(tf.contrib.lite)` to get documentation on functions.
-In particular, `tf.contrib.lite.toco_convert` presents a simple API and
-`tf.contrib.lite.toco_from_protos` allows more detailed control of TOCO using
-the protobuf interface to TOCO.
+The API for converting TensorFlow models to TensorFlow Lite is
+`tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is
+`tf.contrib.lite.Interpreter`.
+
+`TocoConverter` provides class methods based on the original format of the
+model. `TocoConverter.from_session()` is available for GraphDefs.
+`TocoConverter.from_saved_model()` is available for SavedModels. Example usages
+for simple float-point models are shown in [Basic Examples](#basic). Examples
+usages for more complex models is shown in [Complex Examples](#complex).
+
+**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python
+interpreter when the conversion fails. This will be remedied as soon as
+possible.
+
+## Basic examples <a name="basic"></a>
-## Example
+The following section shows examples of how to convert a basic float-point model
+from each of the supported data formats into a TensorFlow Lite FlatBuffers.
-In particular, here we show creating a simple model and converting it to a
-TensorFlow Lite Model.
+### Exporting a GraphDef with constants <a name="basic-graphdef-const"></a>
+
+The following example shows how to convert a TensorFlow GraphDef with constants
+into a TensorFlow Lite FlatBuffer.
```python
import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
+const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
+val = img + const
out = tf.identity(val, name="out")
+
with tf.Session() as sess:
- tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
- open("test.tflite", "wb").write(tflite_model)
+ converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
+ tflite_model = converter.convert()
+ open("converted_model.tflite", "wb").write(tflite_model)
```
-**NOTE** Currently, the TOCO command will cause a fatal error to the Python
-interpreter when TOCO conversion fails. This will be remedied as soon as
-possible.
-
-## Example 2: Export with variables
+### Exporting a GraphDef with variables <a name="basic-graphdef-var"></a>
-If a model has variables, they need to be turned into constants. This process is
-known as freezing, and it can actually be accomplished with
+If a model has variables, they need to be turned into constants through a
+process known as freezing. It can be accomplished by setting `freeze_variables`
+to `True` as shown in the example below.
```python
import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-var = tf.get_variable("weights", dtype=tf.float32, shape=(1,64,64,3))
+var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + var
+out = tf.identity(val, name="out")
-def canonical_name(x):
- return x.name.split(":")[0]
+with tf.Session() as sess:
+ converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out],
+ freeze_variables=True)
+ tflite_model = converter.convert()
+ open("converted_model.tflite", "wb").write(tflite_model)
+```
+
+### Exporting a SavedModel <a name="basic-savedmodel"></a>
+
+The following example shows how to convert a SavedModel into a TensorFlow Lite
+FlatBuffer.
+
+```python
+import tensorflow as tf
+
+converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir)
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
+```
+
+For more complex SavedModels, the optional parameters that can be passed into
+`TocoConverter.from_saved_model()` are `input_arrays`, `input_shapes`,
+`output_arrays`, `tag_set` and `signature_key`. Details of each parameter are
+available by running `help(tf.contrib.lite.TocoConverter)`.
+
+## Complex examples <a name="complex"></a>
+
+For models where the default value of the attributes is not sufficient, the
+variables values should be set before calling `convert()`. In order to call any
+constants use `tf.contrib.lite.constants.<CONSTANT_NAME>` as seen below with
+`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python
+terminal for detailed documentation on the attributes.
+
+Although the examples are demonstrated on GraphDefs containing only constants.
+The same logic can be applied irrespective of the input data format.
+
+### Exporting a quantized GraphDef <a name="complex-quant"></a>
+
+The following example shows how to convert a quantized model into a TensorFlow
+Lite FlatBuffer.
+
+```python
+import tensorflow as tf
+
+img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
+const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
+val = img + const
+out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output")
-out = tf.identity(val, name="out")
with tf.Session() as sess:
- sess.run(tf.global_variables_initializer())
- out_tensors = [out]
- frozen_graphdef = tf.graph_util.convert_variables_to_constants(
- sess, sess.graph_def, map(canonical_name, out_tensors))
- tflite_model = tf.contrib.lite.toco_convert(
- frozen_graphdef, [img], out_tensors)
+ converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
+ converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = [(0., 1.)] # mean, std_dev
+ tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
+
+## TensorFlow Lite Python interpreter <a name="interpreter"></a>
+
+### Using the interpreter from a model file <a name="interpreter-file"></a>
+
+The following example shows how to use the TensorFlow Lite Python interpreter
+when provided a TensorFlow Lite FlatBuffer file. The example also demonstrates
+how to run inference on random input data. Run
+`help(tf.contrib.lite.Interpreter)` in the Python terminal to get detailed
+documentation on the interpreter.
+
+```python
+import numpy as np
+import tensorflow as tf
+
+# Load TFLite model and allocate tensors.
+interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
+interpreter.allocate_tensors()
+
+# Get input and output tensors.
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+
+# Test model on random input data.
+input_shape = input_details[0]['shape']
+input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+interpreter.set_tensor(input_details[0]['index'], input_data)
+
+interpreter.invoke()
+output_data = interpreter.get_tensor(output_details[0]['index'])
+print(output_data)
+```
+
+### Using the interpreter from model data <a name="interpreter-data"></a>
+
+The following example shows how to use the TensorFlow Lite Python interpreter
+when starting with the TensorFlow Lite Flatbuffer model previously loaded. This
+example shows an end-to-end use case, starting from building the TensorFlow
+model.
+
+```python
+import numpy as np
+import tensorflow as tf
+
+img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
+const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
+val = img + const
+out = tf.identity(val, name="out")
+
+with tf.Session() as sess:
+ converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
+ tflite_model = converter.convert()
+
+# Load TFLite model and allocate tensors.
+interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model)
+interpreter.allocate_tensors()
+```
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index f875c85d1a..0f104d5e2d 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -83,7 +83,7 @@ bool ParseModelFlagsFromCommandLineFlags(
"Deprecated: use --input_data_types instead. Input array type, if "
"not already provided in the graph. "
"Typically needs to be specified when passing arbitrary arrays "
- "to --input_array."),
+ "to --input_arrays."),
Flag("input_data_types", parsed_flags.input_data_types.bind(),
parsed_flags.input_data_types.default_value(),
"Input arrays types, comma-separated, if not already provided in "
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
index 6c4f8e12cd..8cac568bd7 100644
--- a/tensorflow/contrib/lite/toco/python/BUILD
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -15,7 +15,7 @@ cc_library(
"//tensorflow/contrib/lite/toco:toco_port",
"//tensorflow/contrib/lite/toco:toco_tooling",
"//tensorflow/core:lib",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -26,7 +26,7 @@ tf_py_wrap_cc(
":toco_python_api",
"//tensorflow/contrib/lite/toco:model_flags_proto_cc",
"//tensorflow/contrib/lite/toco:toco_flags_proto_cc",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h
index 9af38e937c..7e8ad9c1da 100644
--- a/tensorflow/contrib/lite/toco/python/toco_python_api.h
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h
@@ -15,8 +15,8 @@ limitations under the License.
#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_
-#include <string>
#include <Python.h>
+#include <string>
namespace toco {
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
index 4d8d922cb4..5144f7c38c 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
@@ -171,8 +171,7 @@ class NcclManagerTest : public ::testing::Test {
private:
static Allocator* GpuAllocator(BaseGPUDevice* device) {
- return device->GetStepAllocator(AllocatorAttributes(),
- nullptr /* step_resource_manager */);
+ return device->GetAllocator(AllocatorAttributes());
}
static se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py
index bc92a7006f..21bf3f5313 100644
--- a/tensorflow/contrib/opt/python/training/adamax_test.py
+++ b/tensorflow/contrib/opt/python/training/adamax_test.py
@@ -198,11 +198,11 @@ class AdaMaxOptimizerTest(test.TestCase):
self.assertTrue(beta1_power is not None)
self.assertIn(beta1_power, opt_variables)
- with ops.Graph().as_default():
- # Shouldn't return non-slot variables from other graphs.
- self.assertEqual(0, len(opt.variables()))
-
if not context.executing_eagerly():
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
diff --git a/tensorflow/contrib/optimizer_v2/momentum_test.py b/tensorflow/contrib/optimizer_v2/momentum_test.py
index 26724f66c2..24cdab4626 100644
--- a/tensorflow/contrib/optimizer_v2/momentum_test.py
+++ b/tensorflow/contrib/optimizer_v2/momentum_test.py
@@ -134,7 +134,6 @@ class MomentumOptimizerTest(test.TestCase):
with context.eager_mode():
self.doTestBasic(use_resource=True, use_callable_params=True)
- @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testVariablesAcrossGraphs(self):
optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5)
with ops.Graph().as_default():
@@ -142,10 +141,7 @@ class MomentumOptimizerTest(test.TestCase):
[1.0, 2.0], dtype=dtypes.float32, name="var0")
var1 = resource_variable_ops.ResourceVariable(
[3.0, 4.0], dtype=dtypes.float32, name="var1")
- if context.executing_eagerly():
- loss = lambda: math_ops.reduce_sum(var0 + var1)
- else:
- loss = math_ops.reduce_sum(var0 + var1)
+ loss = math_ops.reduce_sum(var0 + var1)
optimizer.minimize(loss)
optimizer_variables = optimizer.variables()
self.assertStartsWith(optimizer_variables[0].name, "var0")
@@ -157,10 +153,7 @@ class MomentumOptimizerTest(test.TestCase):
[1.0, 2.0], dtype=dtypes.float32, name="var2")
var3 = resource_variable_ops.ResourceVariable(
[3.0, 4.0], dtype=dtypes.float32, name="var3")
- if context.executing_eagerly():
- loss = lambda: math_ops.reduce_sum(var2 + var3)
- else:
- loss = math_ops.reduce_sum(var2 + var3)
+ loss = math_ops.reduce_sum(var2 + var3)
optimizer.minimize(loss)
optimizer_variables = optimizer.variables()
self.assertStartsWith(optimizer_variables[0].name, "var2")
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index fdecceff52..6bd58c4d32 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -1,4 +1,4 @@
-package(default_visibility = ["//tensorflow:__subpackages__"])
+package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index 630c0607ae..cfdc884277 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
+#include <deque>
+
#include "tensorflow/contrib/tensorboard/db/summary_converter.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -66,14 +68,9 @@ const char* kImagePluginName = "images";
const char* kAudioPluginName = "audio";
const char* kHistogramPluginName = "histograms";
-const int kScalarSlots = 10000;
-const int kImageSlots = 10;
-const int kAudioSlots = 10;
-const int kHistogramSlots = 1;
-const int kTensorSlots = 10;
-
const int64 kReserveMinBytes = 32;
const double kReserveMultiplier = 1.5;
+const int64 kPreallocateRows = 1000;
// Flush is a misnomer because what we're actually doing is having lots
// of commits inside any SqliteTransaction that writes potentially
@@ -139,22 +136,6 @@ void PatchPluginName(SummaryMetadata* metadata, const char* name) {
}
}
-int GetSlots(const Tensor& t, const SummaryMetadata& metadata) {
- if (metadata.plugin_data().plugin_name() == kScalarPluginName) {
- return kScalarSlots;
- } else if (metadata.plugin_data().plugin_name() == kImagePluginName) {
- return kImageSlots;
- } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) {
- return kAudioSlots;
- } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) {
- return kHistogramSlots;
- } else if (t.dims() == 0 && t.dtype() != DT_STRING) {
- return kScalarSlots;
- } else {
- return kTensorSlots;
- }
-}
-
Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) {
const char* sql = R"sql(
INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?)
@@ -481,24 +462,6 @@ class RunMetadata {
return insert.StepAndReset();
}
- Status GetIsWatching(Sqlite* db, bool* is_watching)
- SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
- mutex_lock lock(mu_);
- if (experiment_id_ == kAbsent) {
- *is_watching = true;
- return Status::OK();
- }
- const char* sql = R"sql(
- SELECT is_watching FROM Experiments WHERE experiment_id = ?
- )sql";
- SqliteStatement stmt;
- TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
- stmt.BindInt(1, experiment_id_);
- TF_RETURN_IF_ERROR(stmt.StepOnce());
- *is_watching = stmt.ColumnInt(0) != 0;
- return Status::OK();
- }
-
private:
Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
@@ -659,43 +622,15 @@ class RunMetadata {
/// \brief Tensor writer for a single series, e.g. Tag.
///
-/// This class can be used to write an infinite stream of Tensors to the
-/// database in a fixed block of contiguous disk space. This is
-/// accomplished using Algorithm R reservoir sampling.
-///
-/// The reservoir consists of a fixed number of rows, which are inserted
-/// using ZEROBLOB upon receiving the first sample, which is used to
-/// predict how big the other ones are likely to be. This is done
-/// transactionally in a way that tries to be mindful of other processes
-/// that might be trying to access the same DB.
-///
-/// Once the reservoir fills up, rows are replaced at random, and writes
-/// gradually become no-ops. This allows long training to go fast
-/// without configuration. The exception is when someone is actually
-/// looking at TensorBoard. When that happens, the "keep last" behavior
-/// is turned on and Append() will always result in a write.
-///
-/// If no one is watching training, this class still holds on to the
-/// most recent "dangling" Tensor, so if Finish() is called, the most
-/// recent training state can be written to disk.
-///
-/// The randomly selected sampling points should be consistent across
-/// multiple instances.
-///
/// This class is thread safe.
class SeriesWriter {
public:
- SeriesWriter(int64 series, int slots, RunMetadata* meta)
- : series_{series},
- slots_{slots},
- meta_{meta},
- rng_{std::mt19937_64::default_seed} {
+ SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} {
DCHECK(series_ > 0);
- DCHECK(slots_ > 0);
}
Status Append(Sqlite* db, int64 step, uint64 now, double computed_time,
- Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
+ const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
if (rowids_.empty()) {
@@ -705,41 +640,20 @@ class SeriesWriter {
return s;
}
}
- DCHECK(rowids_.size() == slots_);
- int64 rowid;
- size_t i = count_;
- if (i < slots_) {
- rowid = last_rowid_ = rowids_[i];
- } else {
- i = rng_() % (i + 1);
- if (i < slots_) {
- rowid = last_rowid_ = rowids_[i];
- } else {
- bool keep_last;
- TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last));
- if (!keep_last) {
- ++count_;
- dangling_tensor_.reset(new Tensor(std::move(t)));
- dangling_step_ = step;
- dangling_computed_time_ = computed_time;
- return Status::OK();
- }
- rowid = last_rowid_;
- }
- }
+ int64 rowid = rowids_.front();
Status s = Write(db, rowid, step, computed_time, t);
if (s.ok()) {
++count_;
- dangling_tensor_.reset();
}
+ rowids_.pop_front();
return s;
}
Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
- // Short runs: Delete unused pre-allocated Tensors.
- if (count_ < rowids_.size()) {
+ // Delete unused pre-allocated Tensors.
+ if (!rowids_.empty()) {
SqliteTransaction txn(*db);
const char* sql = R"sql(
DELETE FROM Tensors WHERE rowid = ?
@@ -747,19 +661,13 @@ class SeriesWriter {
SqliteStatement deleter;
TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
for (size_t i = count_; i < rowids_.size(); ++i) {
- deleter.BindInt(1, rowids_[i]);
+ deleter.BindInt(1, rowids_.front());
TF_RETURN_IF_ERROR(deleter.StepAndReset());
+ rowids_.pop_front();
}
TF_RETURN_IF_ERROR(txn.Commit());
rowids_.clear();
}
- // Long runs: Make last sample be the very most recent one.
- if (dangling_tensor_) {
- DCHECK(last_rowid_ != kAbsent);
- TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_,
- dangling_computed_time_, *dangling_tensor_));
- dangling_tensor_.reset();
- }
return Status::OK();
}
@@ -783,7 +691,6 @@ class SeriesWriter {
Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t,
const StringPiece& data, int64 rowid) {
- // TODO(jart): How can we ensure reservoir fills on replace?
const char* sql = R"sql(
UPDATE OR REPLACE
Tensors
@@ -878,7 +785,7 @@ class SeriesWriter {
// TODO(jart): Maybe preallocate index pages by setting step. This
// is tricky because UPDATE OR REPLACE can have a side
// effect of deleting preallocated rows.
- for (int64 i = 0; i < slots_; ++i) {
+ for (int64 i = 0; i < kPreallocateRows; ++i) {
insert.BindInt(1, series_);
insert.BindInt(2, reserved_bytes);
TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
@@ -902,16 +809,10 @@ class SeriesWriter {
mutex mu_;
const int64 series_;
- const int slots_;
RunMetadata* const meta_;
- std::mt19937_64 rng_ GUARDED_BY(mu_);
uint64 count_ GUARDED_BY(mu_) = 0;
- int64 last_rowid_ GUARDED_BY(mu_) = kAbsent;
- std::vector<int64> rowids_ GUARDED_BY(mu_);
+ std::deque<int64> rowids_ GUARDED_BY(mu_);
uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0;
- std::unique_ptr<Tensor> dangling_tensor_ GUARDED_BY(mu_);
- int64 dangling_step_ GUARDED_BY(mu_) = 0;
- double dangling_computed_time_ GUARDED_BY(mu_) = 0.0;
TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
};
@@ -928,10 +829,10 @@ class RunWriter {
explicit RunWriter(RunMetadata* meta) : meta_{meta} {}
Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now,
- double computed_time, Tensor t, int slots)
+ double computed_time, const Tensor& t)
SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
- SeriesWriter* writer = GetSeriesWriter(tag_id, slots);
- return writer->Append(db, step, now, computed_time, std::move(t));
+ SeriesWriter* writer = GetSeriesWriter(tag_id);
+ return writer->Append(db, step, now, computed_time, t);
}
Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
@@ -948,11 +849,11 @@ class RunWriter {
}
private:
- SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) {
+ SeriesWriter* GetSeriesWriter(int64 tag_id) LOCKS_EXCLUDED(mu_) {
mutex_lock sl(mu_);
auto spot = series_writers_.find(tag_id);
if (spot == series_writers_.end()) {
- SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_);
+ SeriesWriter* writer = new SeriesWriter(tag_id, meta_);
series_writers_[tag_id].reset(writer);
return writer;
} else {
@@ -1082,8 +983,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
TF_RETURN_IF_ERROR(
meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata));
TF_RETURN_WITH_CONTEXT_IF_ERROR(
- run_.Append(db_, tag_id, step, now, computed_time, t,
- GetSlots(t, metadata)),
+ run_.Append(db_, tag_id, step, now, computed_time, t),
meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(),
"/", tag, "@", step);
return Status::OK();
@@ -1155,8 +1055,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
int64 tag_id;
TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
&tag_id, s->metadata()));
- return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t,
- GetSlots(t, s->metadata()));
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
}
// TODO(jart): Refactor Summary -> Tensor logic into separate file.
@@ -1169,8 +1068,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
PatchPluginName(s->mutable_metadata(), kScalarPluginName);
TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
&tag_id, s->metadata()));
- return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
- std::move(t), kScalarSlots);
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
}
Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) {
@@ -1201,8 +1099,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
&tag_id, s->metadata()));
- return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
- std::move(t), kHistogramSlots);
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
}
Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
@@ -1216,8 +1113,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
PatchPluginName(s->mutable_metadata(), kImagePluginName);
TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
&tag_id, s->metadata()));
- return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
- std::move(t), kImageSlots);
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
}
Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
@@ -1230,8 +1126,7 @@ class SummaryDbWriter : public SummaryWriterInterface {
PatchPluginName(s->mutable_metadata(), kAudioPluginName);
TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
&tag_id, s->metadata()));
- return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
- std::move(t), kAudioSlots);
+ return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
}
Env* const env_;
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index 2044692b6e..2e8d4109dd 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -189,7 +189,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
- ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+ ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
int64 user_id = QueryInt("SELECT user_id FROM Users");
int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments");
@@ -238,7 +238,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
- ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+ ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
}
TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
@@ -255,7 +255,7 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
TF_ASSERT_OK(writer_->Flush());
ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
- ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+ ASSERT_EQ(2000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'Ï€'");
int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'");
EXPECT_GT(tag1_id, 0LL);
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 7a8a71ac7f..15811ce0e3 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -303,7 +303,7 @@ tf_cuda_library(
],
deps = [
"//tensorflow/core:framework_lite",
- "//tensorflow/core:platform_base",
+ "//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:nv_infer",
]),
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 32b211dcd1..96e0700862 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -2534,7 +2534,7 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
// Build the TRT op
// TODO(sami,ben,jie): proper naming!
tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp");
- SetInputList(s, &op_builder, &input_names, &input_dtypes);
+ TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes));
std::vector<string> segment_names;
segment_names.reserve(s.subgraph_node_ids.size());
@@ -2632,7 +2632,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
// Build the TRT op
tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
- SetInputList(s, &op_builder, &input_names, &input_dtypes);
+ TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes));
VLOG(0) << "Finished op preparation";
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 816897499b..99485322c6 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -79,7 +79,9 @@ ProfileRequest PopulateProfileRequest(int duration_ms,
request.set_repository_root(repository_root);
request.set_session_id(session_id);
}
+ request.add_tools("op_profile");
request.add_tools("input_pipeline");
+ request.add_tools("memory_viewer");
request.add_tools("overview_page");
*request.mutable_opts() = opts;
std::cout << "Limiting the number of trace events to " << kMaxEvents
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index b9ac1a550c..2b13343efa 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -87,6 +87,8 @@ message StepInfoResult {
optional uint64 wait_duration_ps = 5;
// The time spent on cross-replica-sum in picoseconds.
optional uint64 crs_duration_ps = 6;
+ // Percentage of unit b time spent on infeed.
+ optional double unit_b_infeed_percent = 7;
}
// Result proto for a sequence of steps.
diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
index 7be694e866..f0fca63db0 100644
--- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
+++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
@@ -68,7 +68,8 @@ message ProfileRequest {
}
message ProfileToolData {
- // The tool's name which this data is associated. (e.g. "input_pipeline".)
+ // The file name which this data is associated (e.g. "input_pipeline.json",
+ // "cluster_xxx.memory_viewer.json").
string name = 1;
// The data payload (likely json) for the specific tool.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index e2f57ce9c5..7d165fdd6e 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
@@ -329,6 +330,7 @@ def outside_compilation(computation, args=None):
Returns:
The Tensors returned by computation.
"""
+ args = [] if args is None else args
graph = ops.get_default_graph()
# If we are in a TPUReplicateContext, signal that we are now
@@ -867,3 +869,152 @@ def rewrite(computation,
device_assignment=device_assignment,
name=name)[0]
# pylint: enable=indexing-exception
+
+ # Operations that indicate some error in the user's inference graph.
+_BLACKLISTED_INFERENCE_OPS = set([
+ "ReadVariableOp",
+ "AssignVariableOp",
+ "AssignAddVariableOp",
+ "AssignSubVariableOp",
+ "VarHandleOp",
+ "Variable",
+ "VariableV2",
+])
+
+
+class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
+ """A `ControlFlowContext` for nodes inside a TPU inference computation.
+
+ The primary role of `TPUReplicateContext` is to sanity check operators inside
+ a tpu.rewrite_for_inference() computation.
+ """
+
+ def __init__(self, name):
+ super(_TPUInferenceContext, self).__init__()
+ self._name = name
+
+ def AddOp(self, op):
+ self._AddOpInternal(op)
+
+ def _AddOpInternal(self, op):
+ # pylint: disable=protected-access
+ if op.type in _BLACKLISTED_INFERENCE_OPS:
+ raise NotImplementedError(
+ "Operation of type %s (%s) is not supported on the TPU for inference."
+ " Execution will fail if this op is used in the graph. Make sure your"
+ " variables are using variable_scope." % (op.type, op.name))
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
+ def AddValue(self, val):
+ result = val
+ if self._outer_context:
+ result = self._outer_context.AddValue(val)
+ return result
+
+ def AddInnerOp(self, op):
+ self._AddOpInternal(op)
+
+ @property
+ def grad_state(self):
+ return None
+
+
+@experimental
+def validate_inference_rewrite_for_variables(graph):
+ """Validates whether rewrite_for_inference() 'worked' for variables.
+
+ The rewrite_for_inference() method is supposed to append
+ GuaranteeConstOps after ReadVariableOps, but this mechanism works only
+ if you are using tf.get_variable() to create and access variables in your
+ tpu computation. This validation method can be called immediately after
+ calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps
+ where added to the graph.
+
+ Typical usages:
+ tpu.validate_inference_rewrite_for_variables(tf.get_default_graph())
+
+ tpu.validate_inference_rewrite_for_variables(sess.graph)
+
+ Args:
+ graph: The graph which needs to be validated.
+ Raises:
+ RuntimeError: if validation failed.
+ """
+ if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]):
+ raise RuntimeError(
+ "No GuaranteeConst ops found in the graph after "
+ "running tpu.rewrite_for_inference(...). Please "
+ "check that you are using tf.get_variable() to "
+ "create and access variables in your tpu "
+ "computation.")
+
+
+@experimental
+def rewrite_for_inference(computation,
+ inputs=None,
+ infeed_queue=None,
+ device_assignment=None,
+ name=None):
+ """Rewrites `computation` for inference on a TPU system.
+
+ Other than 'rewriting' the computation to run on a TPU, if using variables
+ in your computation, it moves the ReadVariableOps outside the TPU
+ computation, and adds GuaranteeConst ops just after the ReadVariableOps.
+ This mechanism works only if you are using tf.get_variable() to create and
+ access variables in your tpu computation. You can validate whether
+ this worked, by calling validate_inference_rewrite_for_variables() method
+ immediately after this method to check whether GuaranteeConstOps where
+ added to the graph.
+
+ Args:
+ computation: A Python function that builds a computation to apply
+ to the input. If the function takes n inputs, 'inputs' should be
+ a list of n tensors. If the function returns m outputs, rewrite
+ will return a list of m tensors.
+ inputs: A list of input tensors or `None` (equivalent to an empty list).
+ infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
+ of arguments as inputs to `computation`.
+ device_assignment: if not `None`, a `DeviceAssignment` describing the
+ mapping between logical cores in the computation with physical cores in
+ the TPU topology. May be omitted for a single-core computation, in which
+ case the core attached to task 0, TPU device 0 is used.
+ name: The name of the operator.
+ Returns:
+ A list of output tensors.
+ """
+
+ def guarantee_const_getter(getter, name, *args, **kwargs):
+ with ops.control_dependencies(None):
+ return array_ops.guarantee_const(
+ getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
+
+ def wrapped_computation(*args, **kwargs):
+ """Execute computation under `_TPUInferenceContext`."""
+ context = _TPUInferenceContext(
+ name=ops.get_default_graph().unique_name("rewrite_for_inference"))
+ try:
+ context.Enter()
+
+ vscope = variable_scope.get_variable_scope()
+ prev_custom_getter = vscope.custom_getter
+ prev_caching_device = vscope.caching_device
+ vscope.set_custom_getter(guarantee_const_getter)
+ vscope.set_caching_device(lambda op: op.device)
+
+ result = computation(*args, **kwargs)
+
+ vscope.set_custom_getter(prev_custom_getter)
+ vscope.set_caching_device(prev_caching_device)
+ finally:
+ context.Exit()
+ return result
+
+ # pylint: disable=undefined-variable
+ return rewrite(
+ wrapped_computation,
+ inputs=inputs,
+ infeed_queue=infeed_queue,
+ device_assignment=device_assignment,
+ name=name)
+ # pylint: enable=undefined-variable
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 77d117ba78..f27375637a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -46,6 +46,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -61,6 +62,7 @@ from tensorflow.python.ops import summary_ops_v2 as contrib_summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
@@ -71,6 +73,7 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
+
_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
_TPU_ESTIMATOR = 'tpu_estimator'
@@ -81,6 +84,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
+_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
@@ -1264,13 +1268,11 @@ class _ModelFnWrapper(object):
'estimator_spec used by TPU prediction must have type'
'`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
+ self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)
+
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
to_record = {}
identity_fn = lambda **kwargs: kwargs
- # TODO(xiejw): Adds validation for prediction dictionrary.
- # TODO(xiejw): Adds support for single tensor as predictions.
- if not isinstance(tpu_estimator_spec.predictions, dict):
- raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
to_record['signals'] = [identity_fn, stopping_signals]
if tpu_estimator_spec.host_call is not None:
@@ -1282,6 +1284,21 @@ class _ModelFnWrapper(object):
return predict_step, host_calls, captured_scaffold_fn
+ def _verify_tpu_spec_predictions(self, predictions):
+ """Validates TPUEstimatorSpec.predictions dict."""
+ # TODO(xiejw): Adds validation for prediction dictionrary.
+ # TODO(xiejw): Adds support for single tensor as predictions.
+ if not isinstance(predictions, dict):
+ raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
+
+ for (key, tensor) in predictions.items():
+ if tensor.shape[0].value is None:
+ raise ValueError(
+ 'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
+ 'dynamic shape (should be static). Tensor: {}'.format(
+ key, tensor))
+ return predictions
+
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
model_fn_args = function_utils.fn_args(self._model_fn)
@@ -1760,8 +1777,45 @@ class TPUEstimator(estimator_lib.Estimator):
Exporting
=========
- Exporting `SavedModel` support on TPU is not yet implemented. So,
- `export_savedmodel` is executed on CPU, even if `use_tpu` is true.
+ `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`,
+ and another with `tag_constants.SERVING` and `tag_constants.TPU`.
+ At serving time, these tags are used to select metagraph to load.
+
+ Before running the graph on TPU, TPU system needs to be initialized. If
+ TensorFlow Serving model-server is used, this is done automatically. If
+ not, please call `session.run(tpu.initialize_system())`.
+
+ `tpu.outside_compilation` can be used to wrap TPU incompatible ops in
+ `model_fn`.
+
+ Example:
+ ----------------
+
+ ```
+ def model_fn(features, labels, mode, config, params):
+ ...
+ logits = ...
+ export_outputs = {
+ 'logits': export_output_lib.PredictOutput(
+ {'logits': logits})
+ }
+
+ def host_call(logits):
+ class_ids = math_ops.argmax(logits)
+ classes = string_ops.as_string(class_ids)
+ export_outputs['classes'] =
+ export_output_lib.ClassificationOutput(classes=classes)
+
+ tpu.outside_compilation(host_call, [logits])
+
+ ...
+ ```
+
+ Current limitations:
+ --------------------
+
+ 1. Outside compilation does not work yet (b/79991729).
+
"""
def __init__(self,
@@ -1890,6 +1944,103 @@ class TPUEstimator(estimator_lib.Estimator):
self._is_input_fn_invoked = None
+ def _add_meta_graph_for_mode(self,
+ builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables=True,
+ mode=model_fn_lib.ModeKeys.PREDICT,
+ export_tags=None):
+ if mode != model_fn_lib.ModeKeys.PREDICT:
+ raise NotImplementedError(
+ 'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
+ 'got {}.'.format(mode))
+
+ super(TPUEstimator, self)._add_meta_graph_for_mode(builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables,
+ mode=mode)
+
+ input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE:
+ input_receiver_fn_map[mode]}
+ export_tags = [tag_constants.SERVING, tag_constants.TPU]
+ mode = _REWRITE_FOR_INFERENCE_MODE
+ super(TPUEstimator, self)._add_meta_graph_for_mode(builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables=False,
+ mode=mode,
+ export_tags=export_tags)
+
+ def _call_model_fn(self, features, labels, mode, config):
+ if mode == _REWRITE_FOR_INFERENCE_MODE:
+ return self._call_model_fn_for_inference(features, labels, mode, config)
+ else:
+ return super(TPUEstimator, self)._call_model_fn(
+ features, labels, mode, config)
+
+ def _call_model_fn_for_inference(self, features, labels, mode, config):
+ """Wraps `_call_model_fn` for `export_savedmodel`."""
+ if mode != _REWRITE_FOR_INFERENCE_MODE:
+ raise ValueError('mode must be {}; '
+ 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode))
+
+ capture = _CapturedObject()
+
+ def computation():
+ """Compute tpu tensors used in export_outputs.
+
+ Passed to rewrite_for_inference so that model_fn will be called under
+ the rewriting contexts. Only tpu tensors are returned, but export_outputs
+ and scaffold are captured.
+
+ Returns:
+ A list of Tensors used in export_outputs and not marked for
+ outside_compilation.
+ """
+ # We should only call model fn once and it should be inside `computation`
+ # so that building the graph will happen under `rewrite_for_inference`.
+ mode = model_fn_lib.ModeKeys.PREDICT
+ estimator_spec = self._call_model_fn(features, labels, mode, config)
+
+ # We pick the TPU tensors out from `export_output` and later return them
+ # from `computation` for rewriting.
+ tensors_dict = collections.OrderedDict(
+ (k, _export_output_to_tensors(v))
+ for k, v in six.iteritems(estimator_spec.export_outputs)
+ )
+ tensors = nest.flatten(tensors_dict)
+ tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)]
+
+ # We cannot return anything other than `tpu_tensors` here so we capture
+ # the rest for later use.
+ capture.capture((estimator_spec, tensors_dict, tensors))
+ return tpu_tensors
+
+ tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation)
+ estimator_spec, tensors_dict, tensors = capture.get()
+
+ # Reconstruct `tensors`, but with `tpu_tensors` replaced with
+ # `tpu_tensors_on_cpu`.
+ new_tensors = [
+ tpu_tensors_on_cpu.pop(0) if _is_tpu_tensor(t) else t
+ for t in tensors
+ ]
+ # Reconstruct `tensors_dict`.
+ new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors)
+ # Reconstruct `export_outputs`.
+ export_outputs = estimator_spec.export_outputs
+ new_export_outputs = collections.OrderedDict(
+ (k, _clone_export_output_with_tensors(export_outputs[k], v))
+ for k, v in six.iteritems(new_tensors_dict)
+ )
+
+ return estimator_spec._replace(export_outputs=new_export_outputs)
+
def _create_global_step(self, graph):
"""Creates a global step suitable for TPUs.
@@ -2265,6 +2416,76 @@ class TPUEstimator(estimator_lib.Estimator):
return _model_fn
+def _is_tpu_tensor(tensor):
+ if not isinstance(tensor, ops.Tensor):
+ return False
+ try:
+ tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access
+ except ValueError:
+ return True
+ else:
+ return False
+
+
+def _export_output_to_tensors(export_output):
+ """Get a list of `Tensors` used in `export_output`.
+
+ Args:
+ export_output: an `ExportOutput` object such as `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+ Returns:
+ a list of tensors used in export_output.
+
+ Raises:
+ ValueError: if `export_output` is not one of `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+ """
+ if isinstance(export_output, export_output_lib.ClassificationOutput):
+ return [export_output.scores, export_output.classes]
+ elif isinstance(export_output, export_output_lib.RegressionOutput):
+ return [export_output.value]
+ elif isinstance(export_output, export_output_lib.PredictOutput):
+ return export_output.outputs.values()
+ else:
+ raise ValueError(
+ '`export_output` must be have type `ClassificationOutput`, '
+ '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
+
+
+def _clone_export_output_with_tensors(export_output, tensors):
+ """Clones `export_output` but with new `tensors`.
+
+ Args:
+ export_output: an `ExportOutput` object such as `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+ tensors: a list of `Tensors` used to construct a new `export_output`.
+
+ Returns:
+ A dict similar to `export_output` but with `tensors`.
+
+ Raises:
+ ValueError: if `export_output` is not one of `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+ """
+ if isinstance(export_output, export_output_lib.ClassificationOutput):
+ if len(tensors) != 2:
+ raise ValueError('tensors must be of length 2; '
+ 'got {}.'.format(len(tensors)))
+ return export_output_lib.ClassificationOutput(*tensors)
+ elif isinstance(export_output, export_output_lib.RegressionOutput):
+ if len(tensors) != 1:
+ raise ValueError('tensors must be of length 1; '
+ 'got {}'.format(len(tensors)))
+ return export_output_lib.RegressionOutput(*tensors)
+ elif isinstance(export_output, export_output_lib.PredictOutput):
+ return export_output_lib.PredictOutput(
+ dict(zip(export_output.outputs.keys(), tensors)))
+ else:
+ raise ValueError(
+ '`export_output` must be have type `ClassificationOutput`, '
+ '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
+
+
def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
iterations_per_loop_var = _create_or_get_iterations_per_loop()
@@ -2831,4 +3052,3 @@ def _add_item_to_params(params, key, value):
else:
# Now params is Python dict.
params[key] = value
-
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index a576e36097..10109e5ac1 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -111,8 +111,6 @@ load(
"tf_additional_lib_deps",
"tf_additional_lib_hdrs",
"tf_additional_lib_srcs",
- "tf_additional_framework_hdrs",
- "tf_additional_framework_srcs",
"tf_additional_minimal_lib_srcs",
"tf_additional_proto_hdrs",
"tf_additional_proto_srcs",
@@ -295,43 +293,18 @@ cc_library(
],
)
-PLATFORM_BASE_HDRS = [
- "platform/env_time.h",
- "platform/logging.h",
- "platform/macros.h",
- "platform/types.h",
- "platform/byte_order.h",
-]
-
-PLATFORM_OTHER_HDRS = [
- "platform/abi.h",
- "platform/stacktrace.h",
- "platform/stacktrace_handler.h",
- "platform/context.h",
- "platform/cpu_info.h",
- "platform/cpu_feature_guard.h",
- "platform/dynamic_annotations.h",
- "platform/error.h",
- "platform/env.h",
- "platform/file_system.h",
- "platform/file_system_helper.h",
- "platform/fingerprint.h",
- "platform/init_main.h",
- "platform/mem.h",
- "platform/mutex.h",
- "platform/net.h",
- "platform/notification.h",
- "platform/null_file_system.h",
- "platform/prefetch.h",
- "platform/profile_utils/clock_cycle_profiler.h",
- "platform/profile_utils/cpu_utils.h",
- "platform/protobuf.h",
- "platform/strong_hash.h",
- "platform/subprocess.h",
- "platform/thread_annotations.h",
-]
+filegroup(
+ name = "platform_base_hdrs",
+ srcs = [
+ "platform/byte_order.h",
+ "platform/env_time.h",
+ "platform/logging.h",
+ "platform/macros.h",
+ "platform/types.h",
+ ],
+ visibility = ["//visibility:private"],
+)
-# Smaller platform libraries that don't depend on "lib" or "lib_internal".
cc_library(
name = "platform_base",
srcs = tf_platform_hdrs([
@@ -343,16 +316,262 @@ cc_library(
]) + [
"platform/env_time.cc",
],
- hdrs = PLATFORM_BASE_HDRS,
+ hdrs = [":platform_base_hdrs"],
copts = tf_copts(),
- # TODO(ahentz): remove use of this library so we can move it into 'platform'
tags = ["avoid_dep"],
+ visibility = ["//tensorflow/core:__subpackages__"],
deps = [
":lib_platform",
"//tensorflow/core/platform/default/build_config:base",
],
)
+filegroup(
+ name = "platform_port_hdrs",
+ srcs = [
+ "platform/cpu_info.h",
+ "platform/dynamic_annotations.h",
+ "platform/init_main.h",
+ "platform/mem.h",
+ "platform/mutex.h",
+ "platform/thread_annotations.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+# Headers that are not exported as part of ":lib".
+filegroup(
+ name = "platform_port_internal_hdrs",
+ srcs = [
+ "platform/demangle.h",
+ "platform/host_info.h",
+ "platform/snappy.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+cc_library(
+ name = "platform_port",
+ srcs = tf_platform_hdrs([
+ "cpu_info.h",
+ "dynamic_annotations.h",
+ "thread_annotations.h",
+ "mutex.h",
+ ]) + tf_platform_srcs([
+ "port.cc",
+ ]) + [
+ "platform/cpu_info.cc",
+ ],
+ hdrs = [
+ ":platform_port_hdrs",
+ ":platform_port_internal_hdrs",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow/core:__subpackages__"],
+ deps = [
+ ":lib_platform",
+ ":platform_base",
+ "//tensorflow/core/platform/default/build_config:port",
+ "@snappy",
+ ],
+)
+
+filegroup(
+ name = "platform_protobuf_hdrs",
+ srcs = [
+ "platform/protobuf.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+# Headers that are not exported as part of ":lib".
+filegroup(
+ name = "platform_protobuf_internal_hdrs",
+ srcs = [
+ "platform/protobuf_internal.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+cc_library(
+ name = "platform_protobuf",
+ srcs = tf_platform_hdrs([
+ "protobuf.h",
+ ]) + tf_platform_srcs([
+ "protobuf.cc",
+ ]) + [
+ "platform/protobuf_util.cc",
+ ],
+ hdrs = [
+ ":platform_protobuf_hdrs",
+ ":platform_protobuf_internal_hdrs",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow/core:__subpackages__"],
+ deps = [
+ ":lib_platform",
+ ":platform_base",
+ ":platform_port",
+ "//tensorflow/core/platform/default/build_config:protobuf",
+ "@protobuf_archive//:protobuf",
+ ],
+)
+
+filegroup(
+ name = "platform_env_hdrs",
+ srcs = [
+ "platform/env.h",
+ "platform/file_statistics.h",
+ "platform/file_system.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+# Headers that are not exported as part of ":lib".
+filegroup(
+ name = "platform_env_internal_hdrs",
+ srcs = [
+ "platform/load_library.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+cc_library(
+ name = "platform_env",
+ srcs = tf_platform_srcs([
+ "env.cc",
+ "load_library.cc",
+ ]) + tf_platform_hdrs([
+ "wide_char.h",
+ ]) + [
+ "platform/env.cc",
+ "platform/file_system.cc",
+ ],
+ hdrs = [
+ ":platform_env_hdrs",
+ ":platform_env_internal_hdrs",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow/core:__subpackages__"],
+ deps = [
+ ":error_codes_proto_cc",
+ ":lib",
+ ":lib_internal",
+ ":lib_platform",
+ ":platform_base",
+ ":platform_port",
+ ":platform_protobuf",
+ "//tensorflow/core/platform/default/build_config:env",
+ ],
+)
+
+filegroup(
+ name = "platform_file_system_hdrs",
+ srcs = [
+ "platform/file_system_helper.h",
+ "platform/null_file_system.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+cc_library(
+ name = "platform_file_system",
+ srcs = tf_platform_srcs([
+ ]) + tf_platform_hdrs([
+ "windows_file_system.h",
+ ]) + [
+ "platform/file_system_helper.cc",
+ ],
+ hdrs = [
+ ":platform_file_system_hdrs",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow/core:__subpackages__"],
+ deps = [
+ ":lib",
+ ":lib_platform",
+ ":platform_env",
+ ],
+)
+
+filegroup(
+ name = "platform_other_hdrs",
+ srcs = [
+ "platform/abi.h",
+ "platform/context.h",
+ "platform/cpu_feature_guard.h",
+ "platform/error.h",
+ "platform/fingerprint.h",
+ "platform/net.h",
+ "platform/notification.h",
+ "platform/prefetch.h",
+ "platform/profile_utils/android_armv7a_cpu_utils_helper.h",
+ "platform/profile_utils/clock_cycle_profiler.h",
+ "platform/profile_utils/cpu_utils.h",
+ "platform/profile_utils/i_cpu_utils_helper.h",
+ "platform/stacktrace.h",
+ "platform/stacktrace_handler.h",
+ "platform/strong_hash.h",
+ "platform/subprocess.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+# Headers that are not exported as part of ":lib".
+filegroup(
+ name = "platform_other_internal_hdrs",
+ srcs = [
+ "platform/denormal.h",
+ "platform/setround.h",
+ "platform/tracing.h",
+ ],
+ visibility = ["//visibility:private"],
+)
+
+cc_library(
+ name = "platform_other",
+ srcs = tf_platform_srcs([
+ "subprocess.cc",
+ "net.cc",
+ "tracing.cc",
+ ]) + tf_platform_hdrs([
+ "tracing.h",
+ "error.h",
+ "context.h",
+ "fingerprint.h",
+ "notification.h",
+ "stacktrace.h",
+ "strong_hash.h",
+ "subprocess.h",
+ "tracing_impl.h",
+ ]) + [
+ "platform/cpu_feature_guard.cc",
+ "platform/setround.cc",
+ "platform/tracing.cc",
+ "platform/denormal.cc",
+ "platform/profile_utils/android_armv7a_cpu_utils_helper.cc",
+ "platform/profile_utils/clock_cycle_profiler.cc",
+ "platform/profile_utils/cpu_utils.cc",
+ ],
+ hdrs = [
+ ":platform_other_hdrs",
+ ":platform_other_internal_hdrs",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow/core:__subpackages__"],
+ deps = [
+ ":lib",
+ ":lib_platform",
+ ":platform_base",
+ ":platform_env",
+ ":platform_port",
+ ":platform_protobuf",
+ "//tensorflow/core/platform/default/build_config:other",
+ "//tensorflow/core/platform/default/build_config:platformlib",
+ "//tensorflow/core/platform/default/build_config:port",
+ ],
+)
+
# Minimal lib so that tools used for mobile compilation
# don't have to depend on lib/platformlib.
cc_library(
@@ -386,8 +605,7 @@ cc_library(
# tf_cc_test and tf_cc_binary will include the necessary symbols.
cc_library(
name = "lib",
- hdrs = PLATFORM_BASE_HDRS +
- PLATFORM_OTHER_HDRS + [
+ hdrs = [
"lib/bfloat16/bfloat16.h",
"lib/core/arena.h",
"lib/core/bitmap.h",
@@ -434,6 +652,12 @@ cc_library(
"lib/strings/str_util.h",
"lib/strings/strcat.h",
"lib/strings/stringprintf.h",
+ ":platform_base_hdrs",
+ ":platform_env_hdrs",
+ ":platform_file_system_hdrs",
+ ":platform_other_hdrs",
+ ":platform_port_hdrs",
+ ":platform_protobuf_hdrs",
],
visibility = ["//visibility:public"],
deps = [
@@ -607,6 +831,8 @@ tf_cuda_library(
"util/sparse/group_iterator.h",
"util/sparse/sparse_tensor.h",
"util/stat_summarizer.h",
+ "util/stat_summarizer_options.h",
+ "util/stats_calculator.h",
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
@@ -632,6 +858,16 @@ tf_cuda_library(
)
cc_library(
+ name = "stats_calculator_portable",
+ srcs = ["util/stats_calculator.cc"],
+ hdrs = [
+ "util/stat_summarizer_options.h",
+ "util/stats_calculator.h",
+ ],
+ deps = [":platform_base"],
+)
+
+cc_library(
name = "overflow",
hdrs = ["util/overflow.h"],
deps = [
@@ -1111,6 +1347,7 @@ cc_library(
":shape_inference_testutil",
":tensor_testutil",
":test",
+ ":testlib_ops",
"//tensorflow/cc:scope",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:ops_testutil",
@@ -1118,6 +1355,18 @@ cc_library(
],
)
+cc_library(
+ name = "testlib_ops",
+ testonly = 1,
+ srcs = ["common_runtime/testlib_ops.cc"],
+ linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
# This is a link-only library to provide a DirectSession
# implementation of the Session interface.
tf_cuda_library(
@@ -1768,8 +2017,6 @@ cc_library(
"platform/**/device_tracer.cc",
"platform/**/logging.cc",
"platform/abi.cc",
- "platform/variant_coding.cc",
- "platform/**/variant_cord_coding.cc",
],
) + tf_additional_lib_srcs(
exclude = [
@@ -1782,8 +2029,6 @@ cc_library(
"platform/**/device_tracer.cc",
"platform/**/logging.cc",
"platform/abi.cc",
- "platform/variant_coding.cc",
- "platform/**/variant_cord_coding.cc",
] +
# Protobuf deps already included through the ":lib_proto_parsing"
# dependency.
@@ -2033,7 +2278,6 @@ cc_library(
)
FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [
- "platform/variant_coding.h",
"graph/edgeset.h",
"graph/graph.h",
"graph/graph_def_builder.h",
@@ -2074,14 +2318,13 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
"framework/variant.h",
- "platform/variant_coding.h",
"util/command_line_flags.h",
"util/env_var.h",
"util/equal_graph_def.h",
"util/presized_cuckoo_map.h",
"util/tensor_slice_set.h",
"util/tensor_slice_util.h",
-] + tf_additional_framework_hdrs()
+]
tf_cuda_library(
name = "framework_internal",
@@ -2123,9 +2366,7 @@ cc_header_only_library(
tf_cuda_library(
name = "framework_internal_impl",
- srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + [
- "platform/variant_coding.cc",
- ] + glob(
+ srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + glob(
[
"example/**/*.cc",
"framework/**/*.cc",
@@ -2159,7 +2400,7 @@ tf_cuda_library(
"util/memmapped_file_system.cc",
"util/memmapped_file_system_writer.cc",
],
- }) + tf_additional_framework_srcs(),
+ }),
hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS,
copts = tf_copts(),
linkopts = select({
@@ -2524,6 +2765,7 @@ cc_library(
],
visibility = [
"//tensorflow/compiler:__subpackages__",
+ "//tensorflow/core/kernels:__subpackages__",
"//tensorflow/core/profiler:__subpackages__",
],
deps = [":lib_internal"],
@@ -3761,6 +4003,31 @@ tf_cc_test(
)
tf_cc_test(
+ name = "common_runtime_executor_test",
+ size = "small",
+ srcs = ["common_runtime/executor_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":core",
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":framework",
+ ":framework_internal",
+ ":lib",
+ ":lib_internal",
+ ":protos_all_cc",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/core/kernels:array",
+ "//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:math",
+ "//tensorflow/core/kernels:random_ops",
+ "//tensorflow/core/kernels:state",
+ ],
+)
+
+tf_cc_test(
name = "common_runtime_function_test",
size = "small",
srcs = ["common_runtime/function_test.cc"],
diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt
new file mode 100644
index 0000000000..d8c2ed40a3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "AnonymousIterator"
+ out_arg {
+ name: "handle"
+ description: <<END
+A handle to the iterator that can be passed to a "MakeIterator" or
+"IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents
+resource sharing by name, and does not keep a reference to the resource
+container.
+END
+ }
+ summary: "A container for an iterator resource."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt
index 88049bca36..988bf0a0f8 100644
--- a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt
@@ -1,5 +1,5 @@
op {
graph_op_name: "CollectiveBcastRecv"
- visibility: SKIP
summary: "Receives a tensor value broadcast from another device."
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt
index 7ff70f5b17..d212f6dce7 100644
--- a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt
@@ -1,5 +1,5 @@
op {
graph_op_name: "CollectiveBcastSend"
- visibility: SKIP
summary: "Broadcasts a tensor value to one or more other devices."
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt
index 10d9771d46..fdd9443ba5 100644
--- a/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt
@@ -1,5 +1,5 @@
op {
graph_op_name: "CollectiveReduce"
- visibility: SKIP
summary: "Mutually reduces multiple tensors of identical type and shape."
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt
index ca7e0d3bee..d13866ddaa 100644
--- a/tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt
@@ -38,7 +38,9 @@ END
Computes the string join across dimensions in the given string Tensor of shape
`[d_0, d_1, ..., d_n-1]`. Returns a new Tensor created by joining the input
strings with the given separator (default: empty string). Negative indices are
-counted backwards from the end, with `-1` being equivalent to `n - 1`.
+counted backwards from the end, with `-1` being equivalent to `n - 1`. If
+indices are not specified, joins across all dimensions beginning from `n - 1`
+through `0`.
For example:
@@ -51,9 +53,10 @@ tf.reduce_join(a, -1) = tf.reduce_join(a, 1) ==> ["ab", "cd"]
tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]]
tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]]
tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"]
-tf.reduce_join(a, [0, 1]) ==> ["acbd"]
-tf.reduce_join(a, [1, 0]) ==> ["abcd"]
-tf.reduce_join(a, []) ==> ["abcd"]
+tf.reduce_join(a, [0, 1]) ==> "acbd"
+tf.reduce_join(a, [1, 0]) ==> "abcd"
+tf.reduce_join(a, []) ==> [["a", "b"], ["c", "d"]]
+tf.reduce_join(a) = tf.reduce_join(a, [1, 0]) ==> "abcd"
```
END
}
diff --git a/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt
new file mode 100644
index 0000000000..98b7def4d6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "AnonymousIterator"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CollectiveBcastRecv.pbtxt b/tensorflow/core/api_def/python_api/api_def_CollectiveBcastRecv.pbtxt
new file mode 100644
index 0000000000..78034ccffd
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CollectiveBcastRecv.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "CollectiveBcastRecv"
+ endpoint {
+ name: "collective.broadcast_recv"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CollectiveBcastSend.pbtxt b/tensorflow/core/api_def/python_api/api_def_CollectiveBcastSend.pbtxt
new file mode 100644
index 0000000000..9d6b2f83fe
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CollectiveBcastSend.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "CollectiveBcastSend"
+ endpoint {
+ name: "collective.broadcast_send"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CollectiveReduce.pbtxt b/tensorflow/core/api_def/python_api/api_def_CollectiveReduce.pbtxt
new file mode 100644
index 0000000000..27ae8a833a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CollectiveReduce.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "CollectiveReduce"
+ endpoint {
+ name: "collective.all_reduce"
+ }
+}
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0afbd02e86..07c1eafedc 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -19,15 +19,19 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb_text.h"
@@ -443,6 +447,18 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
// Create a run state and start execution.
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
+ // Set up for collectives if the RunOption declares a key.
+ if (run_options.experimental().collective_graph_key() > 0) {
+ if (!collective_executor_mgr_) {
+ DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+ collective_executor_mgr_.reset(new CollectiveExecutorMgr(
+ options_.config, device_mgr_.get(), drl,
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl,
+ "/job:localhost/replica:0/task:0")));
+ }
+ run_state.collective_executor.reset(new CollectiveExecutor::Handle(
+ collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
+ }
// Start parallel Executors.
const size_t num_executors = executors_and_keys->items.size();
@@ -459,6 +475,9 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
args.step_id = step_id;
args.call_frame = call_frame;
args.rendezvous = run_state.rendez;
+ args.collective_executor =
+ (run_state.collective_executor ? run_state.collective_executor->get()
+ : nullptr);
CancellationManager step_cancellation_manager;
args.cancellation_manager = &step_cancellation_manager;
args.session_state = &session_state_;
@@ -768,6 +787,10 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
args.rendezvous = run_state->rendez;
args.cancellation_manager = cancellation_manager_;
+ // Note that Collectives are not supported in partial runs
+ // because RunOptions is not passed in so we can't know whether
+ // their use is intended.
+ args.collective_executor = nullptr;
args.runner = [this, pool](Executor::Args::Closure c) {
SchedClosure(pool, std::move(c));
};
@@ -1518,11 +1541,13 @@ DirectSession::RunState::RunState(
const std::vector<string>& pending_input_names,
const std::vector<string>& pending_output_names, int64 step_id,
const std::vector<Device*>* devices)
- : step_container(step_id, [devices](const string& name) {
+ : step_container(step_id, [devices, step_id](const string& name) {
for (auto d : *devices) {
if (!d->resource_manager()->Cleanup(name).ok()) {
// Do nothing...
}
+ ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
+ if (sam) sam->Cleanup(step_id);
}
}) {
// Initially all the feeds and fetches are pending.
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 6f9c1b980b..72a2be4816 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
@@ -175,6 +176,7 @@ class DirectSession : public Session {
mutex mu_;
Status status GUARDED_BY(mu_);
IntraProcessRendezvous* rendez = nullptr;
+ std::unique_ptr<CollectiveExecutor::Handle> collective_executor;
std::unique_ptr<StepStatsCollector> collector;
Notification executors_done;
std::unordered_map<string, bool> pending_inputs; // true if fed
@@ -352,6 +354,7 @@ class DirectSession : public Session {
DirectSessionFactory* const factory_; // not owned
CancellationManager* cancellation_manager_;
+ std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
// is true, such as "params" and "queue" nodes. Once placed these
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index a63b2b9711..2a43a31c02 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -86,6 +87,11 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
[](std::function<void()> f) { f(); };
params.runner = &runner;
+ ScopedStepContainer step_container(0, [this](const string& name) {
+ device_->resource_manager()->Cleanup(name).IgnoreError();
+ });
+ params.step_container = &step_container;
+
OpKernelContext context(&params);
if (kernel_->def().op() == "_Recv") {
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 802bfee890..585d777e81 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -592,7 +593,8 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
}
}
}
- if (fwd_status.ok() && forward_from[i] == -1) {
+ if (fwd_status.ok() &&
+ forward_from[i] == OpKernelContext::Params::kNoReservation) {
DCHECK_EQ(forward_input.size() % 2, 0);
for (int j = 0; j < forward_input.size(); j += 2) {
if (forward_input[j + 1] == i) {
@@ -770,7 +772,8 @@ void GraphView::SetScopedAllocatorAttrs(
<< use_node->name();
continue;
}
- // There should be exactly one output using ScopedAllocation.
+ // There can be more than one output using ScopedAllocation, but this
+ // analysis assumes they use the same ScopedAllocator.
for (const auto& e : use_node->out_edges()) {
if (!e->IsControlEdge()) {
AllocatorAttributes attr;
@@ -887,6 +890,11 @@ Status InferAllocAttr(const Node* n, const Node* dst,
<< " remote type " << parsed_dst_name.type;
}
}
+ if (n->IsCollective()) {
+ // We'll make the sweeping assumption that any collective op is going
+ // to be involved in network i/o.
+ attr->set_nic_compatible(true);
+ }
return s;
}
@@ -1289,6 +1297,7 @@ class ExecutorState {
int64 step_id_;
// Not owned.
Rendezvous* rendezvous_;
+ CollectiveExecutor* collective_executor_ = nullptr;
SessionState* session_state_;
TensorStore* tensor_store_;
// Step-local container.
@@ -1411,6 +1420,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
log_memory_(LogMemory::IsEnabled()),
step_id_(args.step_id),
rendezvous_(args.rendezvous),
+ collective_executor_(args.collective_executor),
session_state_(args.session_state),
tensor_store_(args.tensor_store),
step_container_(args.step_container),
@@ -1621,6 +1631,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
params.log_memory = log_memory_;
params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
params.rendezvous = rendezvous_;
+ params.collective_executor = collective_executor_;
params.session_state = session_state_;
params.tensor_store = tensor_store_;
params.cancellation_manager = cancellation_manager_;
@@ -2180,6 +2191,9 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
if (rendezvous_) {
rendezvous_->StartAbort(s);
}
+ if (collective_executor_) {
+ collective_executor_->StartAbort(s);
+ }
if (cancellation_manager_) {
cancellation_manager_->StartCancel();
}
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index adf80a2417..e5d7b7c53c 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -89,6 +89,7 @@ class Executor {
SessionState* session_state = nullptr;
TensorStore* tensor_store = nullptr;
ScopedStepContainer* step_container = nullptr;
+ CollectiveExecutor* collective_executor = nullptr;
// If true, calls Sync() on the device.
bool sync_on_finish = false;
diff --git a/tensorflow/core/distributed_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc
index e34224205b..e34224205b 100644
--- a/tensorflow/core/distributed_runtime/executor_test.cc
+++ b/tensorflow/core/common_runtime/executor_test.cc
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index d05564e9c4..5d9be70522 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -809,6 +810,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
exec_args->cancellation_manager = run_opts.cancellation_manager;
exec_args->step_container = run_opts.step_container;
exec_args->runner = *run_opts.runner;
+ exec_args->collective_executor = run_opts.collective_executor;
Item* item = nullptr;
Status s = GetOrCreateItem(handle, &item);
@@ -896,6 +898,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
exec_args->rendezvous = run_opts.rendezvous;
exec_args->stats_collector = run_opts.stats_collector;
exec_args->cancellation_manager = run_opts.cancellation_manager;
+ exec_args->collective_executor = run_opts.collective_executor;
exec_args->step_container = run_opts.step_container;
exec_args->runner = *run_opts.runner;
exec_args->call_frame = frame;
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index adf2ef6f44..0a1797fa19 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -176,6 +176,9 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
args.step_id = LogMemory::CONSTANT_FOLDING_STEP_ID;
args.runner = runner;
args.rendezvous = rendez;
+ // NOTE: Use of graph runner is limited to single-device executions
+ // so a CollectiveExecutor should never be required.
+ args.collective_executor = nullptr;
// Run the graph.
TF_RETURN_IF_ERROR(executor->Run(args));
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index fe4df1c106..103eee03b3 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -58,11 +58,6 @@ class RenamedDevice : public Device {
return underlying_->GetAllocator(attr);
}
- Allocator* GetStepAllocator(AllocatorAttributes attr,
- ResourceMgr* step_resource_manager) override {
- return underlying_->GetStepAllocator(attr, step_resource_manager);
- }
-
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
return underlying_->eigen_cpu_device();
}
diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
index c045596a69..8ac6adc2e4 100644
--- a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
+++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
@@ -160,13 +160,18 @@ Status ScopedAllocatorMgr::AddScopedAllocator(
expected_call_count);
}
-void ScopedAllocatorMgr::PopulateFields(
+/*static*/
+size_t ScopedAllocatorMgr::PopulateFields(
int32 scope_id, const gtl::ArraySlice<TensorShape>& shapes,
const DataType dtype, std::vector<ScopedAllocator::Field>* fields) {
const int32 num_fields = static_cast<int32>(shapes.size());
fields->resize(num_fields);
size_t offset = 0;
for (int32 i = 0; i < num_fields; ++i) {
+ size_t overshoot = offset % Allocator::kAllocatorAlignment;
+ if (overshoot > 0) {
+ offset += (Allocator::kAllocatorAlignment - overshoot);
+ }
size_t bytes = shapes[i].num_elements() * DataTypeSize(dtype);
(*fields)[i].scope_id = scope_id + 1 + i;
(*fields)[i].bytes = bytes;
@@ -175,11 +180,8 @@ void ScopedAllocatorMgr::PopulateFields(
<< " bytes=" << (*fields)[i].bytes
<< " offset=" << (*fields)[i].offset;
offset += bytes;
- size_t overshoot = offset % Allocator::kAllocatorAlignment;
- if (overshoot > 0) {
- offset += (Allocator::kAllocatorAlignment - overshoot);
- }
}
+ return offset;
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.h b/tensorflow/core/common_runtime/scoped_allocator_mgr.h
index effc5f2d77..8c5e853472 100644
--- a/tensorflow/core/common_runtime/scoped_allocator_mgr.h
+++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.h
@@ -89,10 +89,13 @@ class ScopedAllocatorMgr {
// Populate the bytes and offset members of Field. Instance allocaters get
// consecutive scope_id values following that of the base ScopedAllocator.
- static void PopulateFields(int32 scope_id,
- const gtl::ArraySlice<TensorShape>& shapes,
- const DataType dtype,
- std::vector<ScopedAllocator::Field>* fields);
+ // Returns the total number of bytes required to be allocated in the
+ // backing tensor, for convenience. (The same value can be obtained
+ // by summing offset and bytes in the last field.)
+ static size_t PopulateFields(int32 scope_id,
+ const gtl::ArraySlice<TensorShape>& shapes,
+ const DataType dtype,
+ std::vector<ScopedAllocator::Field>* fields);
const string& device_name() const { return device_name_; }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc b/tensorflow/core/common_runtime/testlib_ops.cc
index 5597ee7a76..a0139c3ee5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc
+++ b/tensorflow/core/common_runtime/testlib_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
@@ -21,8 +22,12 @@ namespace tensorflow {
namespace test {
// ErrorOp::Compute returns an error.
-REGISTER_OP("Error").Input("in: T").Output("out: T").Attr("T: type").Attr(
- "message: string");
+REGISTER_OP("Error")
+ .Input("in: T")
+ .Output("out: T")
+ .Attr("T: type")
+ .Attr("message: string")
+ .SetShapeFn(shape_inference::UnknownShape);
class ErrorOp : public OpKernel {
public:
explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -41,7 +46,8 @@ REGISTER_KERNEL_BUILDER(Name("Error").Device(DEVICE_CPU), ErrorOp);
REGISTER_OP("InvalidRefType")
.Output("out: Ref(TIn)")
.Attr("TIn: type")
- .Attr("TOut: type");
+ .Attr("TOut: type")
+ .SetShapeFn(shape_inference::UnknownShape);
class InvalidRefType : public OpKernel {
public:
explicit InvalidRefType(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -63,8 +69,12 @@ REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU),
// DelayOp::AsyncCompute sleeps for "micros"-econd and then returns
// its input.
-REGISTER_OP("Delay").Input("in: T").Output("out: T").Attr("T: type").Attr(
- "micros: int");
+REGISTER_OP("Delay")
+ .Input("in: T")
+ .Output("out: T")
+ .Attr("T: type")
+ .Attr("micros: int")
+ .SetShapeFn(shape_inference::UnchangedShape);
class DelayOp : public AsyncOpKernel {
public:
explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 18b7069dbe..ead698d787 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -561,17 +561,19 @@ tf_cc_test(
],
)
-# TODO(mrry): Move executor_test.cc to ../common_runtime when once it no longer depends
-# on grpc_testlib.
-tf_cuda_cc_tests(
- name = "executor_tests",
+tf_cuda_cc_test(
+ name = "master_test",
size = "medium",
srcs = [
- "executor_test.cc",
- #"master_test.cc", # TODO(b/27683709): Re-enable when not flaky.
+ "master_test.cc",
],
linkstatic = tf_kernel_tests_linkstatic(),
- tags = tf_cuda_tests_tags(),
+ tags = tf_cuda_tests_tags() + [
+ "manual", # TODO(b/27683709): Re-enable when not flaky.
+ "notap", # TODO(b/27683709): Re-enable when not flaky.
+ "noguitar", # TODO(b/27683709): Re-enable when not flaky.
+ "nooss", # TODO(b/27683709): Re-enable when not flaky.
+ ],
deps = [
":master",
":remote_device",
@@ -588,6 +590,7 @@ tf_cuda_cc_tests(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_master_service_impl",
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
@@ -648,10 +651,10 @@ tf_cuda_cc_test(
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
- "//tensorflow/core/distributed_runtime/rpc:grpc_testlib_ops",
"//tensorflow/core/kernels:aggregate_ops",
"//tensorflow/core/kernels:array",
],
diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc
index f2c1f3489c..0826a90860 100644
--- a/tensorflow/core/distributed_runtime/master_test.cc
+++ b/tensorflow/core/distributed_runtime/master_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "grpc++/grpc++.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/allocator.h"
@@ -37,7 +38,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
-#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 40028ee241..4b2747f26d 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -314,18 +314,6 @@ tf_cc_binary(
],
)
-tf_cuda_library(
- name = "grpc_testlib_ops",
- testonly = 1,
- srcs = ["grpc_testlib_ops.cc"],
- linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- ],
- alwayslink = 1,
-)
-
tf_cc_binary(
name = "grpc_testlib_server",
testonly = 1,
@@ -334,11 +322,11 @@ tf_cc_binary(
],
deps = [
":grpc_server_lib",
- ":grpc_testlib_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:cwise_op",
@@ -362,12 +350,12 @@ tf_cuda_library(
visibility = ["//tensorflow:__subpackages__"],
deps = [
":grpc_session",
- ":grpc_testlib_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
+ "//tensorflow/core:testlib",
],
alwayslink = 1,
)
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 71a31b0e75..d1b495d2ff 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -303,6 +303,9 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
if (format == FORMAT_NCHW_VECT_C) {
dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
context->MakeDim(4);
+ } else if (format == FORMAT_NHWC_VECT_W) {
+ dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
+ context->MakeDim(4);
}
for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) {
dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 223b74857d..ec26d92a61 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -169,13 +169,10 @@ class DeviceBase {
return nullptr;
}
- // Return the Allocator implementation to use based on the allocator
- // attributes requested and the supplied resource manager. By
- // default this ignores the resource manager and calls the base
- // implementation but devices can override if they want to consult
- // the resource manager when choosing the allocator.
- virtual Allocator* GetStepAllocator(AllocatorAttributes attr,
- ResourceMgr* /*step_resource_manager*/) {
+ // DEPRECATED: Use `this->GetAllocator()` or `this->GetScopedAllocator()`.
+ // This method is provided for backwards compatibility, and will be removed
+ // in a future release.
+ Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) {
return GetAllocator(attr);
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e00399f97d..872906756a 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -33,6 +33,7 @@ limitations under the License.
namespace tensorflow {
class CancellationManager;
+class CollectiveExecutor;
class GraphDef;
class OpKernel;
class ProcessFunctionLibraryRuntime;
@@ -484,6 +485,7 @@ class FunctionLibraryRuntime {
int64 step_id = 0;
Rendezvous* rendezvous = nullptr;
CancellationManager* cancellation_manager = nullptr;
+ CollectiveExecutor* collective_executor = nullptr;
ScopedStepContainer* step_container = nullptr;
StepStatsCollector* stats_collector = nullptr;
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index 0873d4e47b..b8309eafb0 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -97,7 +97,7 @@ Status OpRegistry::LookUp(const string& op_type_name,
"Make sure the Op and Kernel are registered in the "
"binary running in this process. Note that if you "
"are loading a saved graph which used ops from "
- "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done"
+ "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
"before importing the graph, as contrib ops are lazily registered "
"when the module is first accessed.");
VLOG(1) << status.ToString();
@@ -256,7 +256,7 @@ Status OpListOpRegistry::LookUp(const string& op_type_name,
"Make sure the Op and Kernel are registered in the "
"binary running in this process. Note that if you "
"are loading a saved graph which used ops from "
- "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done"
+ "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
"before importing the graph, as contrib ops are lazily registered "
"when the module is first accessed.");
}
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c71bcb26ab..b05a9df7c1 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -283,13 +283,13 @@ OpKernelContext::~OpKernelContext() {
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
Allocator* allocator = nullptr;
- if (attr.scope_id > 0) {
+ if (TF_PREDICT_FALSE(attr.scope_id > 0)) {
allocator = params_->device->GetScopedAllocator(attr, step_id());
CHECK(allocator);
} else {
- allocator = params_->device->GetStepAllocator(attr, resource_manager());
+ allocator = params_->device->GetAllocator(attr);
}
- if (track_allocations()) {
+ if (TF_PREDICT_FALSE(track_allocations())) {
mutex_lock lock(mu_);
for (const auto& wrapped : wrapped_allocators_) {
if (wrapped.first == allocator) {
@@ -1273,59 +1273,51 @@ const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
}
#endif
-namespace {
-template <class OpKernelT>
-void CtxFailureInternal(OpKernelT* op_kernel, const char* file, int line,
- const Status& s) {
- const string logging_prefix =
- file == nullptr ? "CtxFailure: "
- : strings::StrCat("CtxFailure at ", io::Basename(file),
- ":", line, ": ");
-
- if (errors::IsOutOfRange(s)) {
- // VLOG OutOfRange errors. Dataset ops create OutOfRange errors when they
- // reach end-of-sequence.
- VLOG(1) << logging_prefix << s;
- } else {
- LOG(WARNING) << logging_prefix << s;
- }
- op_kernel->SetStatus(s);
-}
-} // anonymous namespace
-
void OpKernelConstruction::CtxFailure(const Status& s) {
- CtxFailureInternal(this, nullptr, 0, s);
+ VLOG(1) << s;
+ SetStatus(s);
}
void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
- CtxFailureInternal(this, nullptr, 0, s);
+ LOG(WARNING) << s;
+ SetStatus(s);
}
void OpKernelConstruction::CtxFailure(const char* file, int line,
const Status& s) {
- CtxFailureInternal(this, file, line, s);
+ VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
+ << " : " << s;
+ SetStatus(s);
}
void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
const Status& s) {
- CtxFailureInternal(this, file, line, s);
+ LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
+ << " : " << s;
+ SetStatus(s);
}
void OpKernelContext::CtxFailure(const Status& s) {
- CtxFailureInternal(this, nullptr, 0, s);
+ VLOG(1) << s;
+ SetStatus(s);
}
void OpKernelContext::CtxFailureWithWarning(const Status& s) {
- CtxFailureInternal(this, nullptr, 0, s);
+ LOG(WARNING) << s;
+ SetStatus(s);
}
void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
- CtxFailureInternal(this, file, line, s);
+ VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
+ << " : " << s;
+ SetStatus(s);
}
void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
const Status& s) {
- CtxFailureInternal(this, file, line, s);
+ LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
+ << " : " << s;
+ SetStatus(s);
}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/resource_handle.cc b/tensorflow/core/framework/resource_handle.cc
index 39ef82765f..fc3a329b3b 100644
--- a/tensorflow/core/framework/resource_handle.cc
+++ b/tensorflow/core/framework/resource_handle.cc
@@ -66,4 +66,29 @@ string ProtoDebugString(const ResourceHandle& handle) {
return handle.DebugString();
}
+void EncodeResourceHandleList(const ResourceHandle* p, int64 n,
+ std::unique_ptr<port::StringListEncoder> e) {
+ ResourceHandleProto proto;
+ for (int i = 0; i < n; ++i) {
+ p[i].AsProto(&proto);
+ e->Append(proto);
+ }
+ e->Finalize();
+}
+
+bool DecodeResourceHandleList(std::unique_ptr<port::StringListDecoder> d,
+ ResourceHandle* ps, int64 n) {
+ std::vector<uint32> sizes(n);
+ if (!d->ReadSizes(&sizes)) return false;
+
+ ResourceHandleProto proto;
+ for (int i = 0; i < n; ++i) {
+ if (!proto.ParseFromArray(d->Data(sizes[i]), sizes[i])) {
+ return false;
+ }
+ ps[i].FromProto(proto);
+ }
+ return true;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/resource_handle.h b/tensorflow/core/framework/resource_handle.h
index 06df1b9046..db213669a3 100644
--- a/tensorflow/core/framework/resource_handle.h
+++ b/tensorflow/core/framework/resource_handle.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_HANDLE_H_
#define TENSORFLOW_FRAMEWORK_RESOURCE_HANDLE_H_
+#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -77,6 +78,14 @@ class ResourceHandle {
// For backwards compatibility for when this was a proto
string ProtoDebugString(const ResourceHandle& handle);
+// Encodes a list of ResourceHandle protos in the given StringListEncoder.
+void EncodeResourceHandleList(const ResourceHandle* p, int64 n,
+ std::unique_ptr<port::StringListEncoder> e);
+
+// Decodes a list of ResourceHandle protos from the given StringListDecoder.
+bool DecodeResourceHandleList(std::unique_ptr<port::StringListDecoder> d,
+ ResourceHandle* ps, int64 n);
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_RESOURCE_HANDLE_H_
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 78574bc0b1..21fc6c1bd5 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -138,16 +138,13 @@ string ResourceMgr::DebugString() const {
Status ResourceMgr::DoCreate(const string& container, TypeIndex type,
const string& name, ResourceBase* resource) {
- {
- mutex_lock l(mu_);
- Container** b = &containers_[container];
- if (*b == nullptr) {
- *b = new Container;
- }
- if ((*b)->insert({{type.hash_code(), name}, resource}).second) {
- TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name()));
- return Status::OK();
- }
+ Container** b = &containers_[container];
+ if (*b == nullptr) {
+ *b = new Container;
+ }
+ if ((*b)->insert({{type.hash_code(), name}, resource}).second) {
+ TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name()));
+ return Status::OK();
}
resource->Unref();
return errors::AlreadyExists("Resource ", container, "/", name, "/",
@@ -157,7 +154,6 @@ Status ResourceMgr::DoCreate(const string& container, TypeIndex type,
Status ResourceMgr::DoLookup(const string& container, TypeIndex type,
const string& name,
ResourceBase** resource) const {
- tf_shared_lock l(mu_);
const Container* b = gtl::FindPtrOrNull(containers_, container);
if (b == nullptr) {
return errors::NotFound("Container ", container,
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 621da5b838..33d4cb77ff 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -131,6 +131,10 @@ class ResourceMgr {
// "*resource". Otherwise, invokes creator() to create the resource.
// The caller takes the ownership of one ref on "*resource".
//
+ // WARNING: creator() must not call any methods on ResourceMgr during its
+ // execution, because a non-reentrant lock is held during the creator() call
+ // in order to guarantee atomicity of LookupOrCreate().
+ //
// REQUIRES: std::is_base_of<ResourceBase, T>
// REQUIRES: resource != nullptr
template <typename T>
@@ -174,10 +178,19 @@ class ResourceMgr {
mutable mutex mu_;
std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
+ template <typename T>
+ Status LookupInternal(const string& container, const string& name,
+ T** resource) const
+ SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
+
Status DoCreate(const string& container, TypeIndex type, const string& name,
- ResourceBase* resource) TF_MUST_USE_RESULT;
+ ResourceBase* resource)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
+
Status DoLookup(const string& container, TypeIndex type, const string& name,
- ResourceBase** resource) const TF_MUST_USE_RESULT;
+ ResourceBase** resource) const
+ SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
+
Status DoDelete(const string& container, uint64 type_hash_code,
const string& resource_name,
const string& type_name) TF_MUST_USE_RESULT;
@@ -362,6 +375,7 @@ Status ResourceMgr::Create(const string& container, const string& name,
T* resource) {
CheckDeriveFromResourceBase<T>();
CHECK(resource != nullptr);
+ mutex_lock l(mu_);
return DoCreate(container, MakeTypeIndex<T>(), name, resource);
}
@@ -369,6 +383,13 @@ template <typename T>
Status ResourceMgr::Lookup(const string& container, const string& name,
T** resource) const {
CheckDeriveFromResourceBase<T>();
+ tf_shared_lock l(mu_);
+ return LookupInternal(container, name, resource);
+}
+
+template <typename T>
+Status ResourceMgr::LookupInternal(const string& container, const string& name,
+ T** resource) const {
ResourceBase* found = nullptr;
Status s = DoLookup(container, MakeTypeIndex<T>(), name, &found);
if (s.ok()) {
@@ -383,21 +404,23 @@ template <typename T>
Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
T** resource,
std::function<Status(T**)> creator) {
- Status s;
+ CheckDeriveFromResourceBase<T>();
*resource = nullptr;
- while (*resource == nullptr) {
- s = Lookup(container, name, resource);
- if (s.ok()) break;
- s = creator(resource);
- if (!s.ok()) break;
- s = Create(container, name, *resource);
- if (s.ok()) {
- (*resource)->Ref();
- break;
- }
- // Rare event. Concurrent racy creation. Redo the lookup.
- *resource = nullptr;
+ Status s;
+ {
+ tf_shared_lock l(mu_);
+ s = LookupInternal(container, name, resource);
+ if (s.ok()) return s;
+ }
+ mutex_lock l(mu_);
+ s = LookupInternal(container, name, resource);
+ if (s.ok()) return s;
+ TF_RETURN_IF_ERROR(creator(resource));
+ s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
+ if (!s.ok()) {
+ return errors::Internal("LookupOrCreate failed unexpectedly");
}
+ (*resource)->Ref();
return s;
}
diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc
index 798220d4c3..7c7f0af0ce 100644
--- a/tensorflow/core/framework/resource_mgr_test.cc
+++ b/tensorflow/core/framework/resource_mgr_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
@@ -124,7 +125,7 @@ TEST(ResourceMgrTest, Basic) {
TF_CHECK_OK(rm.Cleanup("bar"));
}
-TEST(ResourceMgr, CreateOrLookup) {
+TEST(ResourceMgrTest, CreateOrLookup) {
ResourceMgr rm;
EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "cat"));
EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "dog"));
@@ -136,6 +137,30 @@ TEST(ResourceMgr, CreateOrLookup) {
HasError(FindErr<Other>(rm, "foo", "bar"), "Not found: Resource foo/bar");
}
+TEST(ResourceMgrTest, CreateOrLookupRaceCondition) {
+ ResourceMgr rm;
+ std::atomic<int> atomic_int(0);
+ {
+ thread::ThreadPool threads(Env::Default(), "racing_creates", 2);
+ for (int i = 0; i < 2; i++) {
+ threads.Schedule([&rm, &atomic_int] {
+ Resource* r;
+ TF_CHECK_OK(rm.LookupOrCreate<Resource>(
+ "container", "resource-name", &r, [&atomic_int](Resource** ret) {
+ // Maximize chance of encountering race condition if one exists.
+ Env::Default()->SleepForMicroseconds(1 * 1000 * 1000);
+ atomic_int += 1;
+ *ret = new Resource("label");
+ return Status::OK();
+ }));
+ r->Unref();
+ });
+ }
+ }
+ // Resource creator function should always run exactly once.
+ EXPECT_EQ(1, atomic_int);
+}
+
Status ComputePolicy(const string& attr_container,
const string& attr_shared_name,
bool use_node_name_as_default, string* result) {
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 167e0eaa6e..384a42fc11 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -51,7 +51,6 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/platform/variant_coding.h"
namespace tensorflow {
@@ -207,7 +206,8 @@ struct Helper<ResourceHandle> {
// "out", which is usually the TensorProto::tensor_content.
template <typename Destination>
static void Encode(TensorBuffer* in, int64 n, Destination* out) {
- port::EncodeResourceHandleList(in->base<const ResourceHandle>(), n, out);
+ EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
+ port::NewStringListEncoder(out));
}
// Decodes "n" elements of type string from "in" and constructs a
@@ -217,7 +217,8 @@ struct Helper<ResourceHandle> {
static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
auto* buf = new Buffer<ResourceHandle>(a, n);
ResourceHandle* ps = buf->template base<ResourceHandle>();
- if (ps == nullptr || !port::DecodeResourceHandleList(in, ps, n)) {
+ if (ps == nullptr ||
+ !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
buf->Unref();
return nullptr;
}
@@ -237,7 +238,8 @@ struct Helper<Variant> {
// "out", which is usually the TensorProto::tensor_content.
template <typename Destination>
static void Encode(TensorBuffer* in, int64 n, Destination* out) {
- port::EncodeVariantList(in->base<const Variant>(), n, out);
+ EncodeVariantList(in->base<const Variant>(), n,
+ port::NewStringListEncoder(out));
}
// Decodes "n" elements of type Variant from "in" and constructs a
@@ -247,7 +249,8 @@ struct Helper<Variant> {
static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
auto* buf = new Buffer<Variant>(a, n);
Variant* ps = buf->template base<Variant>();
- if (ps == nullptr || !port::DecodeVariantList(in, ps, n)) {
+ if (ps == nullptr ||
+ !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
buf->Unref();
return nullptr;
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 58fbced606..d2f2609d3b 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -484,6 +484,7 @@ class Tensor {
friend class TensorTestHelper; // For access to set_shape
friend class OpKernelContext; // For access to RefCountIsOne().
friend class ScopedAllocator; // For access to buf_.
+ friend class XlaTensor; // For access to RefCountIsOne().
friend class XlaTensorBuffer; // For access to the private constructor taking
// the buffer
template <typename Device, typename T>
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
index 6ad2fafee7..5a507804b0 100644
--- a/tensorflow/core/framework/variant.cc
+++ b/tensorflow/core/framework/variant.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -73,4 +74,36 @@ bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
return value->ParseFromString(buf);
}
+void EncodeVariantList(const Variant* variant_array, int64 n,
+ std::unique_ptr<port::StringListEncoder> e) {
+ for (int i = 0; i < n; ++i) {
+ string s;
+ variant_array[i].Encode(&s);
+ e->Append(s);
+ }
+ e->Finalize();
+}
+
+bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
+ Variant* variant_array, int64 n) {
+ std::vector<uint32> sizes(n);
+ if (!d->ReadSizes(&sizes)) return false;
+
+ for (int i = 0; i < n; ++i) {
+ if (variant_array[i].is_empty()) {
+ variant_array[i] = VariantTensorDataProto();
+ }
+ string str(d->Data(sizes[i]), sizes[i]);
+ if (!variant_array[i].Decode(str)) return false;
+ if (!DecodeUnaryVariant(&variant_array[i])) {
+ LOG(ERROR) << "Could not decode variant with type_name: \""
+ << variant_array[i].TypeName()
+ << "\". Perhaps you forgot to register a "
+ "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
+ return false;
+ }
+ }
+ return true;
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
index 5a84f9d943..ded04b2a30 100644
--- a/tensorflow/core/framework/variant_encode_decode.h
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -259,6 +259,16 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf);
template <>
bool DecodeVariant(const string& buf, VariantTensorDataProto* value);
+// Encodes an array of Variant objects in to the given StringListEncoder.
+// `variant_array` is assumed to point to an array of `n` Variant objects.
+void EncodeVariantList(const Variant* variant_array, int64 n,
+ std::unique_ptr<port::StringListEncoder> e);
+
+// Decodes an array of Variant objects from the given StringListDecoder.
+// `variant_array` is assumed to point to an array of `n` Variant objects.
+bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
+ Variant* variant_array, int64 n);
+
} // end namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 71d0637dc2..0f748515ef 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -80,6 +80,9 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
{"Shape", NC_METADATA},
{"Rank", NC_METADATA},
{"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
+ {"CollectiveReduce", NC_COLLECTIVE},
+ {"CollectiveBcastSend", NC_COLLECTIVE},
+ {"CollectiveBcastRecv", NC_COLLECTIVE},
});
#undef REF_CLASS
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 83a69e6b2d..33fb7cb57a 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -163,6 +163,7 @@ class Node {
bool IsHostSend() const { return class_ == NC_HOST_SEND; }
bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
+ bool IsCollective() const { return class_ == NC_COLLECTIVE; }
bool IsMetadata() const { return class_ == NC_METADATA; }
@@ -235,6 +236,7 @@ class Node {
NC_DELETE_SESSION_TENSOR,
NC_METADATA,
NC_SCOPED_ALLOCATOR,
+ NC_COLLECTIVE,
NC_OTHER // Not a special kind of node
};
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index c54b4fa269..6309870190 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -3170,7 +3170,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UnknownOps) {
{"Make sure the Op and Kernel are registered in the "
"binary running in this process. Note that if you "
"are loading a saved graph which used ops from "
- "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done"
+ "tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
"before importing the graph, as contrib ops are lazily registered "
"when the module is first accessed."});
}
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 92581942cb..2a47a4c495 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -78,6 +78,12 @@ bool IsCheckNumerics(const NodeDef& node) {
return node.op() == "CheckNumerics";
}
+bool IsCollective(const NodeDef& node) {
+ return node.op() == "CollectiveReduce" ||
+ node.op() == "CollectiveBcastSend" ||
+ node.op() == "CollectiveBcastRecv";
+}
+
bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
@@ -203,6 +209,8 @@ bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
+bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }
+
bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
bool IsMerge(const NodeDef& node) {
@@ -445,6 +453,10 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
return false;
}
}
+ // Queue ops modify the queue which is a side effect.
+ if (node.op().find("Queue") != std::string::npos) {
+ return false;
+ }
return !ModifiesInputsInPlace(node);
}
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 9d91ba1ba5..e7f39981c0 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -38,6 +38,7 @@ bool IsBiasAddGrad(const NodeDef& node);
bool IsBitcast(const NodeDef& node);
bool IsCast(const NodeDef& node);
bool IsCheckNumerics(const NodeDef& node);
+bool IsCollective(const NodeDef& node);
bool IsComplex(const NodeDef& node);
bool IsComplexAbs(const NodeDef& node);
bool IsConj(const NodeDef& node);
@@ -78,6 +79,7 @@ bool IsLogicalNot(const NodeDef& node);
bool IsLogicalOr(const NodeDef& node);
bool IsMax(const NodeDef& node);
bool IsMaximum(const NodeDef& node);
+bool IsMaxPoolGrad(const NodeDef& node);
bool IsMean(const NodeDef& node);
bool IsMerge(const NodeDef& node);
bool IsMin(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 104a0428ce..c90667abad 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -517,6 +517,7 @@ cc_library(
":memory_optimizer",
":model_pruner",
":remapper",
+ ":scoped_allocator_optimizer",
":shape_optimizer",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
@@ -695,6 +696,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
@@ -761,3 +763,47 @@ tf_cuda_cc_test(
"//tensorflow/core/grappler/utils:grappler_test",
],
)
+
+cc_library(
+ name = "scoped_allocator_optimizer",
+ srcs = ["scoped_allocator_optimizer.cc"],
+ hdrs = [
+ "scoped_allocator_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:scoped_allocator_ops_op_lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
+ ],
+)
+
+tf_cc_test(
+ name = "scoped_allocator_optimizer_test",
+ size = "small",
+ srcs = ["scoped_allocator_optimizer_test.cc"],
+ deps = [
+ ":scoped_allocator_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:direct_session",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 8bdb164b03..1ea916a250 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1631,39 +1631,187 @@ Status ConstantFolding::ReplaceOperationWithConstant(
return Status::OK();
}
-Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
- GraphProperties* properties,
- bool use_shape_info) {
+Status ConstantFolding::SimplifyGraph(bool use_shape_info,
+ GraphDef* optimized_graph,
+ GraphProperties* properties) {
for (int i = 0; i < optimized_graph->node_size(); ++i) {
- TF_RETURN_IF_ERROR(SimplifyNode(optimized_graph->mutable_node(i),
- optimized_graph, properties,
- use_shape_info));
+ TF_RETURN_IF_ERROR(SimplifyNode(use_shape_info,
+ optimized_graph->mutable_node(i),
+ optimized_graph, properties));
}
return Status::OK();
}
-Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
- GraphProperties* properties,
- bool use_shape_info) {
- if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
- ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
+Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
+ GraphDef* optimized_graph,
+ GraphProperties* properties) {
+ if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
return Status::OK();
}
- if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ bool remove_shuffle_transpose_successful = false;
+ Status remove_shuffle_transpose_status =
+ RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
+ node, &remove_shuffle_transpose_successful);
+ if (!remove_shuffle_transpose_status.ok()) {
+ return remove_shuffle_transpose_status;
+ } else if (remove_shuffle_transpose_successful) {
+ return Status::OK();
+ }
+
+ if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
+ return Status::OK();
+ }
+
+ bool remove_reverse_successful = false;
+ Status remove_reverse_status =
+ RemoveReverse(*properties, use_shape_info, optimized_graph, node,
+ &remove_reverse_successful);
+ if (!remove_reverse_status.ok()) {
+ return remove_reverse_status;
+ } else if (remove_reverse_successful) {
+ return Status::OK();
+ }
+
+ bool simplify_slice_successful = false;
+ Status simplify_slice_status =
+ SimplifySlice(*properties, use_shape_info, optimized_graph, node,
+ &simplify_slice_successful);
+ if (!simplify_slice_status.ok()) {
+ return simplify_slice_status;
+ } else if (simplify_slice_successful) {
+ return Status::OK();
+ }
+
+ bool simplify_strided_slice_successful = false;
+ Status simplify_strided_slice_status =
+ SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
+ &simplify_strided_slice_successful);
+ if (!simplify_strided_slice_status.ok()) {
+ return simplify_strided_slice_status;
+ } else if (simplify_strided_slice_successful) {
+ return Status::OK();
+ }
+
+ bool simplify_tile_successful = false;
+ Status simplify_tile_status =
+ SimplifyTile(*properties, use_shape_info, optimized_graph, node,
+ &simplify_tile_successful);
+ if (!simplify_tile_status.ok()) {
+ return simplify_tile_status;
+ } else if (simplify_tile_successful) {
+ return Status::OK();
+ }
+
+ bool simplify_pad_successful = false;
+ Status simplify_pad_status =
+ SimplifyPad(*properties, use_shape_info, optimized_graph, node,
+ &simplify_pad_successful);
+ if (!simplify_pad_status.ok()) {
+ return simplify_pad_status;
+ } else if (simplify_pad_successful) {
+ return Status::OK();
+ }
+
+ if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
+ return Status::OK();
+ }
+
+ if (SimplifyPack(optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (MoveConstantsPastEnter(optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (SimplifySwitch(optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (SimplifyReduction(*properties, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (SimplifyReshape(*properties, use_shape_info, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ bool arithmetic_simplification_succeed = false;
+ Status simplify_arithmetic_status =
+ SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
+ node, &arithmetic_simplification_succeed);
+ if (!simplify_arithmetic_status.ok()) {
+ return simplify_arithmetic_status;
+ } else if (arithmetic_simplification_succeed) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (ReduceDivToReciprocalMul(optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (ConstantPushDown(node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (MulConvPushDown(node, *properties)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialConstPropThroughIdentityN(node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialConcatConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
return Status::OK();
}
- // Remove Shuffle or Transpose op over dimensions of size 1.
+ return Status::OK();
+}
+
+bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
+ GraphDef* optimized_graph,
+ NodeDef* node) {
+ if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
+ return true;
+ }
+
+ if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ return true;
+ }
+ return false;
+}
+
+Status ConstantFolding::RemoveShuffleOrTranspose(
+ const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node, bool* success) {
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& shape = properties->GetInputProperties(node->name())[0].shape();
+ properties.GetInputProperties(node->name()).size() >= 2) {
+ const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) {
// Not optimizable.
return Status::OK();
}
- const auto& p = properties->GetInputProperties(node->name())[1];
+ const auto& p = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor perm(p.dtype(), p.shape());
if (!perm.FromProto(p.value())) {
@@ -1690,34 +1838,45 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
}
if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
}
}
-
- // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
+ *success = false;
+ return Status::OK();
+}
+bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph,
+ NodeDef* node) {
if (use_shape_info && IsRandomShuffle(*node) &&
- !properties->GetInputProperties(node->name()).empty()) {
- const auto& shape = properties->GetInputProperties(node->name())[0].shape();
+ !properties.GetInputProperties(node->name()).empty()) {
+ const auto& shape = properties.GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || first dim is of size 1)
if (!shape.unknown_rank() &&
(shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- return Status::OK();
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ return true;
}
}
+ return false;
+}
- // Remove Reverse op over dimensions with size 1.
+Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
+ bool* success) {
if (use_shape_info && node->op() == "ReverseV2" &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& shape = properties->GetInputProperties(node->name())[0].shape();
+ properties.GetInputProperties(node->name()).size() >= 2) {
+ const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) {
// Not optimizable.
return Status::OK();
}
- const auto& a = properties->GetInputProperties(node->name())[1];
+ const auto& a = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(a.shape()) && a.has_value()) {
Tensor axis(a.dtype(), a.shape());
if (!axis.FromProto(a.value())) {
@@ -1746,17 +1905,25 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
target_axes.find(j) == target_axes.end();
}
if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
}
}
+ *success = false;
+ return Status::OK();
+}
+Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
+ bool* success) {
if (use_shape_info && IsSlice(*node) &&
- properties->GetInputProperties(node->name()).size() == 3) {
- const auto& input = properties->GetInputProperties(node->name())[0];
- const auto& b = properties->GetInputProperties(node->name())[1];
- const auto& s = properties->GetInputProperties(node->name())[2];
+ properties.GetInputProperties(node->name()).size() == 3) {
+ const auto& input = properties.GetInputProperties(node->name())[0];
+ const auto& b = properties.GetInputProperties(node->name())[1];
+ const auto& s = properties.GetInputProperties(node->name())[2];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
@@ -1787,30 +1954,38 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
}
}
if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
}
}
+ *success = false;
+ return Status::OK();
+}
+Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph,
+ NodeDef* node, bool* success) {
if (use_shape_info && IsStridedSlice(*node) &&
- properties->GetInputProperties(node->name()).size() == 4) {
+ properties.GetInputProperties(node->name()).size() == 4) {
if (node->attr().at("new_axis_mask").i() != 0 ||
node->attr().at("shrink_axis_mask").i() != 0) {
// Skip nodes with new/shrink axis mask, since they involve dimension
// changes.
return Status::OK();
}
- const auto& input = properties->GetInputProperties(node->name())[0];
+ const auto& input = properties.GetInputProperties(node->name())[0];
for (int j = 0; j < input.shape().dim_size(); ++j) {
// Skip if input shape is not fully determined.
if (input.shape().dim(j).size() < 0) {
return Status::OK();
}
}
- const auto& b = properties->GetInputProperties(node->name())[1];
- const auto& e = properties->GetInputProperties(node->name())[2];
- const auto& s = properties->GetInputProperties(node->name())[3];
+ const auto& b = properties.GetInputProperties(node->name())[1];
+ const auto& e = properties.GetInputProperties(node->name())[2];
+ const auto& s = properties.GetInputProperties(node->name())[3];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(e.shape()) && e.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
@@ -1879,15 +2054,23 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
(end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
}
if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
}
}
+ *success = false;
+ return Status::OK();
+}
+Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
+ bool* success) {
if (use_shape_info && IsTile(*node) &&
- properties->GetInputProperties(node->name()).size() == 2) {
- const auto& m = properties->GetInputProperties(node->name())[1];
+ properties.GetInputProperties(node->name()).size() == 2) {
+ const auto& m = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(m.shape()) && m.has_value()) {
Tensor multiplies(m.dtype(), m.shape());
if (!multiplies.FromProto(m.value())) {
@@ -1907,15 +2090,23 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
}
}
if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
}
}
+ *success = false;
+ return Status::OK();
+}
+Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
+ bool* success) {
if (use_shape_info && IsPad(*node) &&
- properties->GetInputProperties(node->name()).size() >= 2) {
- const auto& p = properties->GetInputProperties(node->name())[1];
+ properties.GetInputProperties(node->name()).size() >= 2) {
+ const auto& p = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor paddings(p.dtype(), p.shape());
if (!paddings.FromProto(p.value())) {
@@ -1931,18 +2122,26 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
replaceable &= flatten(j) == 0;
}
if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
}
}
+ *success = false;
+ return Status::OK();
+}
+bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph,
+ NodeDef* node) {
if (use_shape_info && IsSqueeze(*node) &&
- !properties->GetInputProperties(node->name()).empty()) {
+ !properties.GetInputProperties(node->name()).empty()) {
// https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
// error to squeeze a dimension that is not 1, so we only need to check
// whether the input has > 1 size for each dimension.
- const auto& shape = properties->GetInputProperties(node->name())[0].shape();
+ const auto& shape = properties.GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size > 1)
bool replaceable = !shape.unknown_rank();
@@ -1950,11 +2149,14 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
replaceable &= shape.dim(j).size() > 1;
}
if (replaceable) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
- return Status::OK();
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ return true;
}
}
+ return false;
+}
+bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
!OptimizedNodeExists(*node, "_const_axis")) {
// Create constant axis node.
@@ -1965,7 +2167,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
.ok()) {
- return Status::OK();
+ return false;
}
// Add a control dependency to make sure axis_node is in the right frame.
const string ctrl_dep = ConstantFolding::AddControlDependency(
@@ -1983,16 +2185,18 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
node->add_input(axis_node->name());
if (node->input_size() > 2) {
node->mutable_input()->SwapElements(1, node->input_size() - 1);
+ return true;
}
- graph_modified_ = true;
- return Status::OK();
}
+ return false;
+}
- // Move constants past Enter.
+bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
+ NodeDef* node) {
if (IsEnter(*node) && node->input_size() > 0) {
if (node->attr().count("is_constant") == 0 ||
!node->attr().at("is_constant").b()) {
- return Status::OK();
+ return false;
}
const string& node_name = node->name();
const NodeDef* input = node_map_->GetNode(node->input(0));
@@ -2029,28 +2233,14 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
}
}
}
- graph_modified_ = true;
- return Status::OK();
+ return true;
}
}
}
+ return false;
+}
- // Switch(x, x) will always feed false to its false branch and true to
- // its true branch. By rewriting the graph a bit, we can propagate these
- // constants down the two output branches, and just use control dependencies
- // to trigger the selected one at runtime. For example,
- //
- // +------+
- // x-->|Switch|-->a (in practice there may be multiple consumers of each
- // x-->| |-->b output branch.)
- // +------+
- //
- // Is rewritten as
- //
- // +------+
- // x-->|Switch|-->Identity--^>Const(false)-->a
- // x-->| |-->Identity--^>Const(true)-->b
- // +------+
+bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
if (node->op() == "Switch" && node->input(0) == node->input(1) &&
!OptimizedNodeExists(*node, "_const_false") &&
!OptimizedNodeExists(*node, "_const_true")) {
@@ -2087,7 +2277,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
false_node->set_name(OptimizedNodeName(*node, "_const_false"));
if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
.ok()) {
- return Status::OK();
+ return false;
}
false_node->set_device(node->device());
@@ -2095,7 +2285,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
true_node->set_name(OptimizedNodeName(*node, "_const_true"));
if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
.ok()) {
- return Status::OK();
+ return false;
}
true_node->set_device(node->device());
@@ -2129,11 +2319,15 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
}
}
}
- graph_modified_ = true;
- return Status::OK();
+ return true;
}
}
- if (IsSimplifiableReduction(*node, *properties)) {
+ return false;
+}
+
+bool ConstantFolding::SimplifyReduction(const GraphProperties& properties,
+ NodeDef* node) {
+ if (IsSimplifiableReduction(*node, properties)) {
// Replace the reduction node with an identity node, that can be further
// optimized by the model pruner.
DataType output_type;
@@ -2147,66 +2341,26 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
node->clear_attr();
(*node->mutable_attr())["T"].set_type(output_type);
*node->mutable_input(1) = AsControlDependency(node->input(1));
- graph_modified_ = true;
- return Status::OK();
- }
- if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
- DataType output_type = node->attr().at("T").type();
- node->set_op("Identity");
- node->clear_attr();
- (*node->mutable_attr())["T"].set_type(output_type);
- *node->mutable_input(1) = AsControlDependency(node->input(1));
- graph_modified_ = true;
- return Status::OK();
- }
-
- bool arithmetic_simplification_succeed = false;
- Status simplify_arithmetic_status = SimplifyArithmeticOperations(
- optimized_graph, properties, node, use_shape_info,
- &arithmetic_simplification_succeed);
- if (!simplify_arithmetic_status.ok()) {
- return simplify_arithmetic_status;
- } else if (arithmetic_simplification_succeed) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (ReduceDivToReciprocalMul(optimized_graph, node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (ConstantPushDown(node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (MulConvPushDown(node, *properties)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (PartialConstPropThroughIdentityN(node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (PartialConcatConstFolding(optimized_graph, properties, node)) {
- graph_modified_ = true;
- return Status::OK();
+ return true;
}
+ return false;
+}
- return Status::OK();
+bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
+ bool use_shape_info, NodeDef* node) {
+ if (!use_shape_info) return false;
+ if (!IsSimplifiableReshape(*node, properties)) return false;
+ DataType output_type = node->attr().at("T").type();
+ node->set_op("Identity");
+ node->clear_attr();
+ (*node->mutable_attr())["T"].set_type(output_type);
+ *node->mutable_input(1) = AsControlDependency(node->input(1));
+ return true;
}
Status ConstantFolding::SimplifyArithmeticOperations(
- GraphDef* optimized_graph, GraphProperties* properties, NodeDef* node,
- bool use_shape_info, bool* success) {
+ const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node, bool* success) {
const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
@@ -2215,8 +2369,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
// Simplify arithmetic operations with ones or zeros.
if (use_shape_info &&
(is_mul || is_matmul || is_add || is_sub || is_any_div) &&
- properties->HasInputProperties(node->name()) &&
- properties->HasOutputProperties(node->name())) {
+ properties.HasInputProperties(node->name()) &&
+ properties.HasOutputProperties(node->name())) {
const NodeDef* x = node_map_->GetNode(node->input(0));
const NodeDef* y = node_map_->GetNode(node->input(1));
if (x == nullptr || y == nullptr) {
@@ -2224,19 +2378,19 @@ Status ConstantFolding::SimplifyArithmeticOperations(
node->DebugString());
}
const TensorShapeProto& output_shape =
- properties->GetOutputProperties(node->name())[0].shape();
+ properties.GetOutputProperties(node->name())[0].shape();
// Simplify element-wise multiplication by ones or addition/subtraction
// of zeros.
const TensorShapeProto& y_shape =
- properties->GetInputProperties(node->name())[1].shape();
+ properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
- ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
+ ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
@@ -2259,14 +2413,14 @@ Status ConstantFolding::SimplifyArithmeticOperations(
}
const TensorShapeProto& x_shape =
- properties->GetInputProperties(node->name())[0].shape();
+ properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
- ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
+ ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
@@ -2276,9 +2430,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
bool replace_succeed = false;
- Status replace_op_status =
- ReplaceOperationWithConstant(1, *properties, output_shape, node,
- optimized_graph, &replace_succeed);
+ Status replace_op_status = ReplaceOperationWithConstant(
+ 1, properties, output_shape, node, optimized_graph, &replace_succeed);
if (!replace_op_status.ok()) {
return replace_op_status;
} else if (replace_succeed) {
@@ -2296,7 +2449,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
if (shp.IsFullyDefined()) {
bool replace_succeed = false;
Status replace_op_status =
- ReplaceOperationWithConstant(0, *properties, output_shape, node,
+ ReplaceOperationWithConstant(0, properties, output_shape, node,
optimized_graph, &replace_succeed);
if (!replace_op_status.ok()) {
return replace_op_status;
@@ -2309,11 +2462,11 @@ Status ConstantFolding::SimplifyArithmeticOperations(
// matches the output shape and thus forward the corresponding zero
// input.
if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
} else if (is_mul && y_is_zero && y_matches_output_shape) {
- ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
+ ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
@@ -2855,7 +3008,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
TF_RETURN_IF_ERROR(FoldGraph(optimized_graph));
node_map_.reset(new NodeMap(optimized_graph));
TF_RETURN_IF_ERROR(
- SimplifyGraph(optimized_graph, &properties, can_use_shape_info));
+ SimplifyGraph(can_use_shape_info, optimized_graph, &properties));
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index e477934f30..b42d5f201e 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -97,10 +97,10 @@ class ConstantFolding : public GraphOptimizer {
const GraphProperties& properties) const;
bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
- Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
- bool use_shape_info);
- Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
- GraphProperties* properties, bool use_shape_info);
+ Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph,
+ GraphProperties* properties);
+ Status SimplifyNode(bool use_shape_info, NodeDef* node,
+ GraphDef* optimized_graph, GraphProperties* properties);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);
@@ -134,11 +134,81 @@ class ConstantFolding : public GraphOptimizer {
// Simplifies arithmetic operations with ones or zeros. Returns the status,
// and updates the success input argument that denotes if any simplification
// was applied.
- Status SimplifyArithmeticOperations(GraphDef* optimized_graph,
- GraphProperties* properties,
- NodeDef* node, bool use_shape_info,
+ Status SimplifyArithmeticOperations(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
bool* success);
+ // Simplifies a Reshape operation to an Identity operation if applicable.
+ bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
+ NodeDef* node);
+
+ // Simplifies a Reduction operation to an Identity operation if applicable.
+ bool SimplifyReduction(const GraphProperties& properties, NodeDef* node);
+
+ // Switch(x, x) will always feed false to its false branch and true to
+ // its true branch. By rewriting the graph a bit, we can propagate these
+ // constants down the two output branches, and just use control dependencies
+ // to trigger the selected one at runtime. For example,
+ //
+ // +------+
+ // x-->|Switch|-->a (in practice there may be multiple consumers of each
+ // x-->| |-->b output branch.)
+ // +------+
+ //
+ // Is rewritten as
+ //
+ // +------+
+ // x-->|Switch|-->Identity--^>Const(false)-->a
+ // x-->| |-->Identity--^>Const(true)-->b
+ // +------+
+ bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node);
+
+ // Moves constants past Enter node if applicable.
+ bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node);
+
+ // Simplifies Pack operation if applicable.
+ bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
+
+ // Simplifies a Squeeze operation to an Identity operation if applicable.
+ bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node);
+
+ // Simplifies a Pad operation to an Identity operation if applicable.
+ Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node, bool* success);
+
+ // Simplifies a Tile operation to an Identity operation if applicable.
+ Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node, bool* success);
+
+ // Simplifies a StridedSlice operation to an Identity operation if applicable.
+ Status SimplifyStridedSlice(const GraphProperties& properties,
+ bool use_shape_info, GraphDef* optimized_graph,
+ NodeDef* node, bool* success);
+
+ // Simplifies a Slice operation to an Identity operation if applicable.
+ Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node, bool* success);
+
+ // Removes Reverse op over dimensions with size 1.
+ Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node, bool* success);
+
+ // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
+ bool RemoveRandomShuffle(const GraphProperties& properties,
+ bool use_shape_info, GraphDef* optimized_graph,
+ NodeDef* node);
+
+ // Removes Shuffle or Transpose op over dimensions of size 1.
+ Status RemoveShuffleOrTranspose(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node,
+ bool* success);
+
+ // Removes Split or SplitV node if possible.
+ bool RemoveSplitOrSplitV(const GraphProperties& properties,
+ GraphDef* optimized_graph, NodeDef* node);
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
new file mode 100644
index 0000000000..d3fe7df583
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -0,0 +1,76 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+
+cc_library(
+ name = "graph_utils",
+ srcs = ["graph_utils.cc"],
+ hdrs = [
+ "graph_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:grappler_item_builder",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/optimizers:meta_optimizer",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "graph_utils_test",
+ srcs = ["graph_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "map_and_batch_fusion",
+ srcs = ["map_and_batch_fusion.cc"],
+ hdrs = [
+ "map_and_batch_fusion.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_and_batch_fusion_test",
+ srcs = ["map_and_batch_fusion_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":map_and_batch_fusion",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
+
+cc_library(
+ name = "data",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":map_and_batch_fusion",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
new file mode 100644
index 0000000000..df12de37da
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -0,0 +1,217 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/grappler_item_builder.h"
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_utils {
+namespace {
+
+int FindNodeWithPredicate(const std::function<bool(const NodeDef&)>& predicate,
+ const GraphDef& graph) {
+ for (int i = 0; i < graph.node_size(); ++i) {
+ if (predicate(graph.node(i))) {
+ return i;
+ }
+ }
+ return -1;
+}
+
+std::vector<int> CreateNameIndex(const GraphDef& graph) {
+ std::map<string, int> names;
+ for (int i = 0; i < graph.node_size(); ++i) {
+ names[graph.node(i).name()] = i;
+ }
+ std::vector<int> index(graph.node_size());
+ int i = 0;
+ for (const auto& pair : names) {
+ index[i++] = pair.second;
+ }
+ return index;
+}
+
+std::vector<int> CreateInputIndex(const NodeDef& node) {
+ std::map<string, int> inputs;
+ for (int i = 0; i < node.input_size(); ++i) {
+ inputs[node.input(i)] = i;
+ }
+ std::vector<int> index(node.input_size());
+ int i = 0;
+ for (const auto& pair : inputs) {
+ index[i++] = pair.second;
+ }
+ return index;
+}
+
+Status AddScalarConstNodeHelper(
+ DataType dtype, const std::function<void(TensorProto*)>& add_value,
+ GraphDef* graph, NodeDef** result) {
+ NodeDef* node = graph->add_node();
+ const string& name = strings::StrCat("Const/_", graph->node_size());
+ node->set_name(name);
+ node->set_op("Const");
+ (*node->mutable_attr())["dtype"].set_type(dtype);
+ std::unique_ptr<tensorflow::TensorProto> tensor =
+ tensorflow::MakeUnique<tensorflow::TensorProto>();
+ std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
+ tensorflow::MakeUnique<tensorflow::TensorShapeProto>();
+ tensor->set_allocated_tensor_shape(tensor_shape.release());
+ tensor->set_dtype(dtype);
+ add_value(tensor.get());
+ (*node->mutable_attr())["value"].set_allocated_tensor(tensor.release());
+ *result = node;
+ return Status::OK();
+}
+
+} // namespace
+
+Status AddNode(const string& name, const string& op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ GraphDef* graph, NodeDef** result) {
+ NodeDef* node = graph->add_node();
+ if (!name.empty()) {
+ node->set_name(name);
+ } else {
+ node->set_name(strings::StrCat(op, "/_", graph->node_size()));
+ }
+ node->set_op(op);
+ for (const string& input : inputs) {
+ node->add_input(input);
+ }
+ for (auto attr : attributes) {
+ (*node->mutable_attr())[attr.first] = attr.second;
+ }
+ *result = node;
+ return Status::OK();
+}
+
+template <>
+Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result) {
+ return AddScalarConstNodeHelper(
+ DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph,
+ result);
+}
+
+template <>
+Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result) {
+ return AddScalarConstNodeHelper(
+ DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph,
+ result);
+}
+
+template <>
+Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result) {
+ return AddScalarConstNodeHelper(
+ DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph,
+ result);
+}
+
+template <>
+Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result) {
+ return AddScalarConstNodeHelper(
+ DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph,
+ result);
+}
+
+template <>
+Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result) {
+ return AddScalarConstNodeHelper(
+ DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph,
+ result);
+}
+
+template <>
+Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result) {
+ return AddScalarConstNodeHelper(
+ DT_STRING,
+ [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
+ graph, result);
+}
+
+bool Compare(const GraphDef& g1, const GraphDef& g2) {
+ if (g1.node_size() != g2.node_size()) {
+ return false;
+ }
+ std::vector<int> name_index1 = CreateNameIndex(g1);
+ std::vector<int> name_index2 = CreateNameIndex(g2);
+ for (int i = 0; i < g1.node_size(); ++i) {
+ int idx1 = name_index1[i];
+ int idx2 = name_index2[i];
+ if (g1.node(idx1).op() != g2.node(idx2).op()) {
+ return false;
+ }
+ if (g1.node(idx1).name() != g2.node(idx2).name()) {
+ return false;
+ }
+ if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) {
+ return false;
+ }
+ std::vector<int> input_index1 = CreateInputIndex(g1.node(idx1));
+ std::vector<int> input_index2 = CreateInputIndex(g2.node(idx2));
+ for (int j = 0; j < g1.node(idx1).input_size(); ++j) {
+ if (!IsSameInput(g1.node(idx1).input(input_index1[j]),
+ g2.node(idx2).input(input_index2[j]))) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool ContainsNodeWithName(const string& name, const GraphDef& graph) {
+ return FindNodeWithName(name, graph) != -1;
+}
+
+bool ContainsNodeWithOp(const string& op, const GraphDef& graph) {
+ return FindNodeWithOp(op, graph) != -1;
+}
+
+Status DeleteNodes(const std::set<string>& nodes_to_delete, GraphDef* graph) {
+ int last = graph->node_size() - 1;
+ for (int i = graph->node_size() - 1; i >= 0; --i) {
+ const NodeDef& node = graph->node(i);
+ if (nodes_to_delete.find(node.name()) != nodes_to_delete.end()) {
+ graph->mutable_node()->SwapElements(i, last);
+ last--;
+ }
+ }
+ graph->mutable_node()->DeleteSubrange(last + 1,
+ graph->node_size() - last - 1);
+ return Status::OK();
+}
+
+int FindNodeWithName(const string& name, const GraphDef& graph) {
+ return FindNodeWithPredicate(
+ [name](const NodeDef& node) { return node.name() == name; }, graph);
+}
+
+int FindNodeWithOp(const string& op, const GraphDef& graph) {
+ return FindNodeWithPredicate(
+ [op](const NodeDef& node) { return node.op() == op; }, graph);
+}
+
+} // end namespace graph_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
new file mode 100644
index 0000000000..b40ca44d78
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_utils {
+
+// Adds a node to the graph.
+Status AddNode(const string& name, const string& op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ GraphDef* graph, NodeDef** result);
+
+// Adds a Const node with the given value to the graph.
+template <typename T>
+Status AddScalarConstNode(T v, GraphDef* graph, NodeDef** result) {
+ return errors::Unimplemented("Type %s is not supported.",
+ DataTypeToEnum<T>::value);
+}
+template <>
+Status AddScalarConstNode(bool v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(double v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(float v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(int v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(int64 v, GraphDef* graph, NodeDef** result);
+template <>
+Status AddScalarConstNode(StringPiece v, GraphDef* graph, NodeDef** result);
+
+// Checks whether the two graphs are the same.
+bool Compare(const GraphDef& g1, const GraphDef& g2);
+
+// Checks whether the graph contains a node with the given name.
+bool ContainsNodeWithName(const string& name, const GraphDef& graph);
+
+// Checks whether the graph contains a node with the given op.
+bool ContainsNodeWithOp(const string& op, const GraphDef& graph);
+
+// Deletes nodes from the graph.
+Status DeleteNodes(const std::set<string>& nodes_to_delete, GraphDef* graph);
+
+// Returns the index of the node with the given name or -1 if the node does
+// not exist.
+int FindNodeWithName(const string& name, const GraphDef& graph);
+
+// Returns the index of a node with the given op or -1 if no such node
+// exists.
+int FindNodeWithOp(const string& op, const GraphDef& graph);
+
+} // end namespace graph_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
new file mode 100644
index 0000000000..b34726044e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_utils {
+namespace {
+
+class GraphUtilsTest : public ::testing::Test {};
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeBool) {
+ GraphDef graph;
+ NodeDef* bool_node;
+ TF_EXPECT_OK(AddScalarConstNode<bool>(true, &graph, &bool_node));
+ EXPECT_TRUE(ContainsNodeWithName(bool_node->name(), graph));
+ EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) {
+ GraphDef graph;
+ NodeDef* double_node;
+ TF_EXPECT_OK(AddScalarConstNode<double>(3.14, &graph, &double_node));
+ EXPECT_TRUE(ContainsNodeWithName(double_node->name(), graph));
+ EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) {
+ GraphDef graph;
+ NodeDef* float_node;
+ TF_EXPECT_OK(AddScalarConstNode<float>(3.14, &graph, &float_node));
+ EXPECT_TRUE(ContainsNodeWithName(float_node->name(), graph));
+ EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeInt) {
+ GraphDef graph;
+ NodeDef* int_node;
+ TF_EXPECT_OK(AddScalarConstNode<int>(42, &graph, &int_node));
+ EXPECT_TRUE(ContainsNodeWithName(int_node->name(), graph));
+ EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) {
+ GraphDef graph;
+ NodeDef* int64_node;
+ TF_EXPECT_OK(AddScalarConstNode<int64>(42, &graph, &int64_node));
+ EXPECT_TRUE(ContainsNodeWithName(int64_node->name(), graph));
+ EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
+}
+
+TEST_F(GraphUtilsTest, AddScalarConstNodeString) {
+ GraphDef graph;
+ NodeDef* string_node;
+ TF_EXPECT_OK(AddScalarConstNode<StringPiece>("hello", &graph, &string_node));
+ EXPECT_TRUE(ContainsNodeWithName(string_node->name(), graph));
+ EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
+}
+
+TEST_F(GraphUtilsTest, Compare) {
+ GraphDef graphA;
+ GraphDef graphB;
+ EXPECT_TRUE(Compare(graphA, graphB));
+
+ NodeDef* nodeA;
+ TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graphA, &nodeA));
+ NodeDef* nodeB;
+ TF_EXPECT_OK(AddNode("B", "OpB", {"A"}, {}, &graphA, &nodeB));
+ EXPECT_FALSE(Compare(graphA, graphB));
+
+ graphB.mutable_node()->CopyFrom(graphA.node());
+ EXPECT_TRUE(Compare(graphA, graphB));
+}
+
+TEST_F(GraphUtilsTest, ContainsNodeWithName) {
+ GraphDef graph;
+ EXPECT_TRUE(!ContainsNodeWithName("A", graph));
+
+ NodeDef* node;
+ TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
+ EXPECT_TRUE(ContainsNodeWithName("A", graph));
+
+ TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
+ EXPECT_TRUE(!ContainsNodeWithName("A", graph));
+}
+
+TEST_F(GraphUtilsTest, ContainsNodeWithOp) {
+ GraphDef graph;
+ EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
+
+ NodeDef* node;
+ TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
+ EXPECT_TRUE(ContainsNodeWithOp("OpA", graph));
+
+ TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
+ EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
+}
+
+TEST_F(GraphUtilsTest, FindNodeWithName) {
+ GraphDef graph;
+ EXPECT_EQ(FindNodeWithName("A", graph), -1);
+
+ NodeDef* node;
+ TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
+ EXPECT_NE(FindNodeWithName("A", graph), -1);
+
+ TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
+ EXPECT_EQ(FindNodeWithName("A", graph), -1);
+}
+
+TEST_F(GraphUtilsTest, FindNodeWithOp) {
+ GraphDef graph;
+ EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
+
+ NodeDef* node;
+ TF_EXPECT_OK(AddNode("A", "OpA", {}, {}, &graph, &node));
+ EXPECT_NE(FindNodeWithOp("OpA", graph), -1);
+
+ TF_EXPECT_OK(DeleteNodes({"A"}, &graph));
+ EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
+}
+
+} // namespace
+} // namespace graph_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
new file mode 100644
index 0000000000..5b8df61c48
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ GraphView graph(output);
+ std::set<string> nodes_to_delete;
+ for (const NodeDef& node : item.graph.node()) {
+ if (node.op() != "BatchDataset") {
+ continue;
+ }
+
+ // Use a more descriptive variable name now that we now the node type.
+ NodeDef batch_node(node);
+ GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
+ NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
+ continue;
+ }
+
+ // Use a more descriptive variable name now that we now the node type.
+ NodeDef* map_node = node2;
+ NodeDef* new_node = output->mutable_node()->Add();
+ new_node->set_op("MapAndBatchDatasetV2");
+ new_node->set_name(
+ strings::StrCat("MapAndBatchDatasetV2/_", output->node_size()));
+
+ // Set the `input` input argument.
+ new_node->add_input(map_node->input(0));
+
+ // Set the `other_arguments` input arguments.
+ int num_other_args;
+ if (map_node->op() == "ParallelMapDataset") {
+ num_other_args = map_node->input_size() - 2;
+ } else {
+ num_other_args = map_node->input_size() - 1;
+ }
+ for (int i = 0; i < num_other_args; i++) {
+ new_node->add_input(map_node->input(i + 1));
+ }
+
+ // Set the `batch_size` input argument.
+ new_node->add_input(batch_node.input(1));
+
+ // Set the `num_parallel_calls` input argument.
+ if (map_node->op() == "ParallelMapDataset") {
+ // The type of the `num_parallel_calls` argument in ParallelMapDataset
+ // and MapAndBatchDataset is different (int32 and int64 respectively)
+ // so we cannot reuse the same Const node and thus create a new one.
+ NodeDef* v = graph.GetNode(map_node->input(map_node->input_size() - 1));
+ NodeDef* tmp;
+ TF_RETURN_IF_ERROR(graph_utils::AddScalarConstNode<int64>(
+ v->attr().at("value").tensor().int_val(0), output, &tmp));
+ new_node->add_input(tmp->name());
+ } else {
+ NodeDef* tmp;
+ TF_RETURN_IF_ERROR(
+ graph_utils::AddScalarConstNode<int64>(1, output, &tmp));
+ new_node->add_input(tmp->name());
+ }
+
+ // Set the `drop_remainder` input argument.
+ {
+ NodeDef* tmp;
+ TF_RETURN_IF_ERROR(
+ graph_utils::AddScalarConstNode<bool>(false, output, &tmp));
+ new_node->add_input(tmp->name());
+ }
+
+ // Set `f` and `Targuments` attributes.
+ new_node->mutable_attr()->insert(map_node->attr().begin(),
+ map_node->attr().end());
+ // Set `output_types` and `output_shapes` attributes.
+ new_node->mutable_attr()->insert(batch_node.attr().begin(),
+ batch_node.attr().end());
+
+ // Mark the `Map` and `Batch` nodes for removal.
+ nodes_to_delete.insert(map_node->name());
+ nodes_to_delete.insert(batch_node.name());
+
+ // Update the input of the outputs of the `Batch` node to use
+ // `MapAndBatch`.
+ GraphView::OutputPort output_port =
+ graph.GetOutputPort(batch_node.name(), 0);
+ auto fanout = graph.GetFanout(output_port);
+ for (auto it = fanout.begin(); it != fanout.end(); ++it) {
+ NodeDef* node = it->node;
+ node->set_input(0, new_node->name());
+ }
+ }
+ TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
+ return Status::OK();
+}
+
+void MapAndBatchFusion::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapAndBatchFusion, "map_and_batch_fusion");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
new file mode 100644
index 0000000000..a5a4d91df6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class MapAndBatchFusion : public CustomGraphOptimizer {
+ public:
+ MapAndBatchFusion() {}
+ ~MapAndBatchFusion() override {}
+
+ string name() const override { return "map_and_batch_fusion"; };
+
+ Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
+ nullptr) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
new file mode 100644
index 0000000000..51e7f37e7e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h"
+
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
+ std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ empty_attributes, graph, &range_node));
+ NodeDef *captured_input_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
+ "hello", graph, &captured_input_node));
+
+ std::vector<string> map_inputs(2);
+ map_inputs[0] = range_node->name();
+ map_inputs[1] = captured_input_node->name();
+ NodeDef *map_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs,
+ empty_attributes, graph, &map_node));
+
+ NodeDef *batch_size_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ std::vector<string> batch_inputs(2);
+ batch_inputs[0] = map_node->name();
+ batch_inputs[1] = batch_size_node->name();
+ NodeDef *batch_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
+ empty_attributes, graph, &batch_node));
+
+ MapAndBatchFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node =
+ output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ EXPECT_EQ(map_and_batch_node.input_size(), 5);
+ EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
+ EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
+ EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
+ NodeDef num_parallel_calls_node = output.node(
+ graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+ EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0),
+ 1);
+ NodeDef drop_remainder_node = output.node(
+ graph_utils::FindNodeWithName(map_and_batch_node.input(4), output));
+ EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false);
+}
+
+TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
+ std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ empty_attributes, graph, &range_node));
+ NodeDef *captured_input_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
+ "hello", graph, &captured_input_node));
+ NodeDef *num_parallel_calls_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int>(2, graph, &num_parallel_calls_node));
+
+ std::vector<string> map_inputs(3);
+ map_inputs[0] = range_node->name();
+ map_inputs[1] = captured_input_node->name();
+ map_inputs[2] = num_parallel_calls_node->name();
+ NodeDef *map_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "ParallelMapDataset", map_inputs,
+ empty_attributes, graph, &map_node));
+
+ NodeDef *batch_size_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ std::vector<string> batch_inputs(2);
+ batch_inputs[0] = map_node->name();
+ batch_inputs[1] = batch_size_node->name();
+ NodeDef *batch_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
+ empty_attributes, graph, &batch_node));
+
+ MapAndBatchFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node =
+ output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ EXPECT_EQ(map_and_batch_node.input_size(), 5);
+ EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
+ EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
+ EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
+ NodeDef num_parallel_calls_node2 = output.node(
+ graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+ EXPECT_EQ(num_parallel_calls_node2.attr().at("value").tensor().int64_val(0),
+ 2);
+ NodeDef drop_remainder_node = output.node(
+ graph_utils::FindNodeWithName(map_and_batch_node.input(4), output));
+ EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false);
+}
+
+TEST(MapAndBatchFusionTest, NoChange) {
+ std::vector<std::pair<string, AttrValue>> empty_attributes;
+
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ empty_attributes, graph, &range_node));
+
+ MapAndBatchFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_TRUE(graph_utils::Compare(*graph, output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 611d871eea..fa228c68a1 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -610,6 +610,9 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
// Turn input placeholders into identity nodes.
CHECK_EQ(0, func_body_node.input_size());
func_body_node.set_op("Identity");
+ (*func_body_node.mutable_attr())["T"] = func_body_node.attr().at("dtype");
+ func_body_node.mutable_attr()->erase("dtype");
+ func_body_node.mutable_attr()->erase("shape");
int input_idx = input_placeholders_idx[func_body_node.name()];
func_body_node.add_input(
strings::StrCat(func_inputs->name(), ":", input_idx));
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index a92727535d..e6622486eb 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
+#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/functions.h"
@@ -88,6 +89,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization()));
MK_OPT("dependency", new DependencyOptimizer(cfg_.dependency_optimization()));
MK_OPT("debug_stripper", new DebugStripper());
+ MK_OPT("scoped_allocator",
+ new ScopedAllocatorOptimizer(cfg_.scoped_allocator_opts()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -145,6 +148,10 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->emplace_back(
new AutoParallel(cfg_.auto_parallel().num_replicas()));
}
+ if (cfg_.scoped_allocator_optimization()) {
+ optimizers->emplace_back(
+ new ScopedAllocatorOptimizer(cfg_.scoped_allocator_opts()));
+ }
return Status::OK();
}
@@ -211,12 +218,32 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
bool is_optimized = false;
GraphOptimizationResult optimization_result(item.id);
+ // ScopedAllocatorOptimizer must run last, so move it to the
+ // end of optimizers and run only on the last iteration.
+ {
+ int sa_index = 0;
+ for (; sa_index < optimizers.size(); ++sa_index) {
+ if (optimizers[sa_index]->name() == "scoped_allocator_optimizer") {
+ break;
+ }
+ }
+ const int last_index = optimizers.size() - 1;
+ if (sa_index < last_index) {
+ optimizers[last_index].swap(optimizers[sa_index]);
+ }
+ }
+
+ const int last_iteration = NumIterations(cfg_) - 1;
for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
VLOG(4) << "Starting optimization iteration " << iteration + 1;
for (const auto& optimizer : optimizers) {
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
+ // Some must run only on the last iteration.
+ if (optimizer->name() == "scoped_allocator_optimizer" &&
+ iteration != last_iteration)
+ continue;
uint64 start_us = Env::Default()->NowMicros();
// This swaps the current optimized_graph into optimized item and
@@ -361,6 +388,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.auto_parallel().enable() ||
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
+ cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 2a62871293..efd870b118 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -28,10 +28,71 @@ namespace grappler {
void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
const string& x = fused_node.input(0);
- const string& scale = fused_node.input(1);
- const string& offset = fused_node.input(2);
- const string& mean = fused_node.input(3);
- const string& variance = fused_node.input(4);
+ string scale = fused_node.input(1);
+ string offset = fused_node.input(2);
+ string mean = fused_node.input(3);
+ string variance = fused_node.input(4);
+
+ if (fused_node.attr().at("data_format").s() == "NCHW") {
+ // Need to reshape the last 4 inputs
+ NodeDef* new_shape = optimized_graph->add_node();
+ new_shape->set_name(AddPrefixToNodeName("NCHWShape", fused_node.name()));
+ new_shape->set_op("Const");
+ new_shape->set_device(fused_node.device());
+ *new_shape->add_input() = AsControlDependency(scale);
+ (*new_shape->mutable_attr())["dtype"].set_type(DT_INT32);
+ Tensor t(DT_INT32, {4});
+ t.flat<int32>()(0) = 1;
+ t.flat<int32>()(1) = -1;
+ t.flat<int32>()(2) = 1;
+ t.flat<int32>()(3) = 1;
+ t.AsProtoTensorContent(
+ (*new_shape->mutable_attr())["value"].mutable_tensor());
+
+ NodeDef* reshaped_scale = optimized_graph->add_node();
+ reshaped_scale->set_name(
+ AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
+ reshaped_scale->set_op("Reshape");
+ reshaped_scale->set_device(fused_node.device());
+ *reshaped_scale->add_input() = scale;
+ *reshaped_scale->add_input() = new_shape->name();
+ (*reshaped_scale->mutable_attr())["T"] = fused_node.attr().at("T");
+ (*reshaped_scale->mutable_attr())["Tshape"].set_type(DT_INT32);
+ scale = reshaped_scale->name();
+
+ NodeDef* reshaped_offset = optimized_graph->add_node();
+ reshaped_offset->set_name(
+ AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
+ reshaped_offset->set_op("Reshape");
+ reshaped_offset->set_device(fused_node.device());
+ *reshaped_offset->add_input() = offset;
+ *reshaped_offset->add_input() = new_shape->name();
+ (*reshaped_offset->mutable_attr())["T"] = fused_node.attr().at("T");
+ (*reshaped_offset->mutable_attr())["Tshape"].set_type(DT_INT32);
+ offset = reshaped_offset->name();
+
+ NodeDef* reshaped_mean = optimized_graph->add_node();
+ reshaped_mean->set_name(
+ AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
+ reshaped_mean->set_op("Reshape");
+ reshaped_mean->set_device(fused_node.device());
+ *reshaped_mean->add_input() = mean;
+ *reshaped_mean->add_input() = new_shape->name();
+ (*reshaped_mean->mutable_attr())["T"] = fused_node.attr().at("T");
+ (*reshaped_mean->mutable_attr())["Tshape"].set_type(DT_INT32);
+ mean = reshaped_mean->name();
+
+ NodeDef* reshaped_variance = optimized_graph->add_node();
+ reshaped_variance->set_name(
+ AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
+ reshaped_variance->set_op("Reshape");
+ reshaped_variance->set_device(fused_node.device());
+ *reshaped_variance->add_input() = variance;
+ *reshaped_variance->add_input() = new_shape->name();
+ (*reshaped_variance->mutable_attr())["T"] = fused_node.attr().at("T");
+ (*reshaped_variance->mutable_attr())["Tshape"].set_type(DT_INT32);
+ variance = reshaped_variance->name();
+ }
float epsilon = 0.0f;
if (fused_node.attr().count("epsilon")) {
@@ -118,20 +179,16 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
optimizable &= (node.attr().count("is_training") == 0 ||
!node.attr().at("is_training").b());
if (optimizable) {
- std::unordered_set<int> const_inputs;
- for (const string& input : node.input()) {
- int pos;
- const string input_node = ParseNodeName(input, &pos);
- if (properties.HasInputProperties(input_node)) {
- const auto& props = properties.GetInputProperties(input_node);
- if (props.size() > pos && props[pos].has_value()) {
- const_inputs.insert(pos);
- }
+ int const_inputs = 0;
+ const auto& props = properties.GetInputProperties(node.name());
+ for (const auto& prop : props) {
+ if (prop.has_value()) {
+ const_inputs += 1;
}
}
// TODO(bsteiner): use the cost model to compare the cost of fused batch
// norm against that of the optimized form.
- optimizable = (const_inputs.size() >= 4);
+ optimizable = (const_inputs >= 4);
}
if (optimizable) {
for (GraphView::Edge edge : graph.GetFanoutEdges(node, false)) {
@@ -143,6 +200,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
}
}
if (optimizable) {
+ std::cout << "Optimizing fused batch norm node " << node.DebugString()
+ << std::endl;
AddBatchNormNodes(optimized_graph, node);
continue;
}
diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc
index 291585c538..4cbf0d8d6f 100644
--- a/tensorflow/core/grappler/optimizers/remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/platform/test.h"
@@ -54,5 +55,41 @@ TEST_F(RemapperTest, FusedBatchNorm) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
+TEST_F(RemapperTest, FusedBatchNormNCHW) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output dflt =
+ ops::Const(s.WithOpName("dflt"), {3.14f, 2.7f, 1.0f, 2.0f, 3.0f, 100.0f},
+ {1, 3, 1, 2});
+ Output x = ops::PlaceholderWithDefault(s.WithOpName("x"), dflt, {1, 3, 1, 2});
+ Output scale = ops::Const(s.WithOpName("scale"), {0.3f, 7.0f, 123.0f}, {3});
+ Output offset =
+ ops::Const(s.WithOpName("offset"), {0.123f, 2.1f, 0.55f}, {3});
+ Output mean = ops::Const(s.WithOpName("mean"), {7.3f, 8.3f, 3.1f}, {3});
+ Output variance =
+ ops::Const(s.WithOpName("variance"), {0.57f, 1.0f, 2.0f}, {3});
+ ops::FusedBatchNorm::Attrs attr;
+ attr = attr.IsTraining(false);
+ attr = attr.DataFormat("NCHW");
+ ops::FusedBatchNorm bn(s.WithOpName("batch_norm").WithDevice("/device:GPU:0"),
+ x, scale, offset, mean, variance, attr);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"batch_norm"};
+
+ Remapper optimizer(RewriterConfig::ON);
+ GraphDef output;
+ TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+ if (GetNumAvailableGPUs() > 0) {
+ // NCHW batch norm is only supported on GPU.
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ }
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
new file mode 100644
index 0000000000..cceef4098d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
@@ -0,0 +1,929 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
+
+#include "tensorflow/core/common_runtime/scoped_allocator.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils/frame.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+// Like TF_RETURN_IF_ERROR, but also logs a WARNING.
+#define LOG_WARNING_AND_RETURN_IF_ERROR(...) \
+ do { \
+ const ::tensorflow::Status _status = (__VA_ARGS__); \
+ if (TF_PREDICT_FALSE(!_status.ok())) { \
+ LOG(WARNING) << "error: " << _status; \
+ return _status; \
+ } \
+ } while (0)
+
+namespace tensorflow {
+namespace grappler {
+
+namespace {
+// Node names often have some kind of name_scope prefix, with slashes,
+// and a _nn numeric suffix. Returns true if the main part of the node_name
+// matches op_name, i.e. it looks from the name like this node is
+// of that op type.
+bool HasOpName(const string& node_name, const string& op_name) {
+ size_t begin = node_name.rfind("/");
+ if (begin == string::npos) {
+ begin = 0;
+ } else {
+ ++begin;
+ }
+ size_t end = node_name.rfind("_");
+ if (end != string::npos) {
+ size_t p = end + 1;
+ while (p < node_name.size()) {
+ if (!isdigit(node_name[p])) {
+ end = node_name.size();
+ break;
+ }
+ ++p;
+ }
+ } else {
+ end = node_name.size();
+ }
+ return node_name.substr(begin, end - begin) == op_name;
+}
+
+// After shape inference has been done each op should be annotated
+// with its output shape(s). This function iterates over a collection
+// of ops that are a potential application of a ScopedAllocator. It
+// verifies whether they all have the same output type and if so
+// gathers a vector of their output shapes. It returns an error if
+// any of the ops doesn't have type or shape data, or if it has more
+// than one output, of if the output type of all ops is not the same.
+// If it returns OK then *type and *shapes should be correctly populated.
+Status CheckTypesAndGetShapes(const GraphProperties& graph_properties,
+ const std::vector<NodeDef*>& ops, DataType* type,
+ std::vector<TensorShape>* shapes) {
+ VLOG(1) << "CheckTypesAndGetShapes";
+ *type = DT_INVALID;
+ for (NodeDef* n : ops) {
+ AttrSlice n_attrs = AttrSlice(*n);
+ DataType dtype;
+ LOG_WARNING_AND_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
+ VLOG(2) << "op " << n->name() << " has type " << dtype << " shapes.size() "
+ << shapes->size();
+ if (!graph_properties.HasOutputProperties(n->name())) {
+ LOG(ERROR) << "Node " << n->DebugString() << " lacks output shape.";
+ return errors::Internal("Node ", n->name(), " lacks output shape.");
+ }
+ const std::vector<OpInfo::TensorProperties>& prop_list =
+ graph_properties.GetOutputProperties(n->name());
+ if (prop_list.size() != 1) {
+ return errors::Internal("Node ", n->name(),
+ " does not have exactly one output as expected "
+ "by ScopedAllocatorOptimizer");
+ }
+ const OpInfo::TensorProperties& props = prop_list[0];
+ if (shapes->empty()) {
+ *type = props.dtype();
+ } else if (*type != props.dtype()) {
+ return errors::Internal("Group ops don't all have same type");
+ } else if (!TensorShape::IsValid(props.shape())) {
+ return errors::Internal("Complete shape not known for ", n->name());
+ }
+ VLOG(2) << "Adding shape " << props.shape().DebugString();
+ shapes->push_back(TensorShape(props.shape()));
+ }
+ return Status::OK();
+}
+
+// Describes an existing input edge in the graph.
+struct InputDesc {
+ NodeDef* from_node_def;
+ int output_slot;
+ NodeDef* to_node_def;
+ InputDesc(NodeDef* f, int os, NodeDef* t)
+ : from_node_def(f), output_slot(os), to_node_def(t) {}
+};
+
+// Populates *inputs with all of the non-control inputs of ops.
+// Returns error if it fails to find exactly one input for each op,
+// or if some input is not of type dtype.
+Status GetInputs(NodeMap* node_map, const std::vector<NodeDef*>& ops,
+ DataType dtype, std::vector<InputDesc>* inputs) {
+ VLOG(1) << "Getinputs";
+ for (NodeDef* n : ops) {
+ NodeDef* inode = nullptr;
+ int position = 0;
+ VLOG(2) << "for node " << n->name();
+ for (const auto& input_name : n->input()) {
+ if (!IsControlInput(input_name)) {
+ if (inode) {
+ return errors::Internal("Found more than one input for node ",
+ n->name());
+ }
+ ParseNodeName(input_name, &position);
+ inode = node_map->GetNode(input_name);
+ CHECK(inode) << input_name;
+ VLOG(2) << "inode " << inode->DebugString();
+ }
+ }
+ AttrSlice inode_attrs = AttrSlice(*inode);
+ DataType inode_dtype;
+ LOG_WARNING_AND_RETURN_IF_ERROR(
+ GetNodeAttr(inode_attrs, "T", &inode_dtype));
+ if (inode_dtype != dtype) {
+ return errors::Internal("ScopedAllocatorOptimizer expected input type ",
+ dtype, " but found ", inode_dtype);
+ }
+ // inputs->push_back(InputDesc(inode, position, n));
+ inputs->emplace_back(inode, position, n);
+ }
+ return Status::OK();
+}
+
+// Remove the NodeDef nd from node_map and graph. It must be the case
+// that nd no longer has any input or output edges, though that is not
+// checked.
+void RemoveNode(NodeDef* nd, GraphDef* graph, NodeMap* node_map) {
+ node_map->RemoveNode(nd->name());
+ // TODO(tucker): The efficiency of this routine is poor.
+ // Change to accumulate and do a bulk removal, maybe refactoring
+ // some code from dependency_optimizer.
+ protobuf::RepeatedPtrField<NodeDef>* nodes = graph->mutable_node();
+ for (int i = 0; i < nodes->size(); ++i) {
+ if (nd->name() == (*nodes)[i].name()) {
+ nodes->SwapElements(i, nodes->size() - 1);
+ nodes->RemoveLast();
+ return;
+ }
+ }
+ LOG(FATAL) << "Failed to find node " << nd->name() << " in graph";
+}
+
+// Removes a named edge from between two nodes.
+Status RemoveEdge(const string& input_edge_name, const string& from_node_name,
+ NodeDef* to_node, NodeMap* node_map) {
+ if (node_map) {
+ node_map->RemoveOutput(from_node_name, to_node->name());
+ }
+ protobuf::RepeatedPtrField<string>* inputs = to_node->mutable_input();
+ int edge_index = -1;
+ for (edge_index = 0; edge_index < inputs->size(); ++edge_index) {
+ VLOG(2) << " consider edge " << (*inputs)[edge_index];
+ if ((*inputs)[edge_index] == input_edge_name) {
+ break;
+ }
+ }
+ if (edge_index >= inputs->size()) {
+ return errors::Internal("Could not find input name ", input_edge_name,
+ " at node ", to_node->name());
+ }
+ inputs->DeleteSubrange(edge_index, 1);
+ return Status::OK();
+}
+} // namespace
+
+void ScopedAllocatorOptimizer::ExtendNodeAttr(StringPiece name,
+ const std::vector<int32>& values,
+ NodeDef* node_def) {
+ if (HasNodeAttr(*node_def, name)) {
+ VLOG(2) << "extending";
+ AttrValue* existing = &(*node_def->mutable_attr())[name.ToString()];
+ for (int32 i : values) {
+ existing->mutable_list()->add_i(i);
+ }
+ } else {
+ VLOG(2) << "setting new attr value";
+ AddNodeAttr(name, values, node_def);
+ }
+}
+
+class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
+ public:
+ ~UnaryElementwiseRewriter() override {}
+
+ // Return non-OK if any input is already committed to a ScopedAllocator.
+ Status CheckExistingScopedAllocator(const std::vector<InputDesc>& inputs) {
+ for (const InputDesc& nd : inputs) {
+ VLOG(2) << "get attrs for " << nd.from_node_def->name();
+ AttrSlice n_attrs = AttrSlice(*nd.from_node_def);
+ int sa_id;
+ Status ss = GetNodeAttr(n_attrs, "sa_id", &sa_id);
+ if (ss.ok()) {
+ LOG(INFO) << "Abandoning PARewriter because input "
+ << nd.from_node_def->name() << " is already assigned "
+ << "to ScopedAllocator " << sa_id;
+ return errors::Internal(
+ "Abandoning PARewriter because input ", nd.from_node_def->name(),
+ " is already assigned to ScopedAllocator ", sa_id);
+ }
+ }
+ return Status::OK();
+ }
+
+ // Return non-OK if any input is a member of op_set.
+ Status CheckInternalDataDependency(const std::set<string>& op_set,
+ const std::vector<InputDesc>& inputs) {
+ for (const InputDesc& nd : inputs) {
+ if (op_set.find(nd.from_node_def->name()) != op_set.end()) {
+ if (nd.output_slot != tensorflow::Graph::kControlSlot) {
+ return errors::Internal("Data edge exists bewtween ",
+ nd.from_node_def->name(),
+ " and another "
+ "node in the set");
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ // Remove all control edges between members of ops.
+ void ClearInternalControlInputs(const std::set<string>& op_set,
+ const std::vector<NodeDef*>& ops,
+ NodeMap* node_map) {
+ for (NodeDef* n : ops) {
+ for (const auto& input_name : n->input()) {
+ if (IsControlInput(input_name)) {
+ int position = 0;
+ string input_node_name = ParseNodeName(input_name, &position);
+ CHECK_EQ(position, -1);
+ if (op_set.find(input_node_name) != op_set.end()) {
+ // This is an internal control edge. Remove it.
+ VLOG(1) << "Remove control output from " << input_node_name
+ << " via edge " << input_name << " to " << n->name();
+ TF_CHECK_OK(RemoveEdge(input_name, input_node_name, n, node_map));
+ }
+ }
+ }
+ }
+ }
+
+ // Examine the input set of an op set, gathering their shapes and types
+ // and checking whether there are any considerations that prevent use
+ // of a single ScopedAllocator for all of those inputs.
+ Status AnalyzeInputs(ScopedAllocatorOptimizer* sa_opti, NodeMap* node_map,
+ const std::vector<NodeDef*>& ops,
+ const std::set<string>& op_instance_names,
+ string* device_name, DataType* dtype,
+ std::vector<TensorShape>* input_shapes,
+ std::vector<InputDesc>* inputs, TensorShape* sa_shape) {
+ CHECK(graph_properties_);
+ LOG_WARNING_AND_RETURN_IF_ERROR(
+ CheckTypesAndGetShapes(*graph_properties_, ops, dtype, input_shapes));
+ LOG_WARNING_AND_RETURN_IF_ERROR(
+ GetInputs(sa_opti->node_map(), ops, *dtype, inputs));
+ LOG_WARNING_AND_RETURN_IF_ERROR(CheckExistingScopedAllocator(*inputs));
+ LOG_WARNING_AND_RETURN_IF_ERROR(
+ CheckInternalDataDependency(op_instance_names, *inputs));
+ ClearInternalControlInputs(op_instance_names, ops, node_map);
+ *device_name = ops[0]->device();
+ CHECK(!device_name->empty());
+ CHECK(!input_shapes->empty());
+ CHECK_EQ(0, Allocator::kAllocatorAlignment % DataTypeSize(*dtype))
+ << "ScopedAllocatorOptimizer only applies to types that evenly "
+ << "divide kAllocatorAlignment";
+ std::vector<ScopedAllocator::Field> sa_fields;
+ // Calculate the field embedding boundaries and thereby the
+ // required size of the backing tensor.
+ int64 num_bytes = ScopedAllocatorMgr::PopulateFields(
+ 0 /*scope_id*/, *input_shapes, *dtype, &sa_fields);
+ int64 num_elts = num_bytes / DataTypeSize(*dtype);
+ VLOG(2) << "num_bytes " << num_bytes << " num_elts=" << num_elts;
+ *sa_shape = TensorShape({num_elts});
+ return Status::OK();
+ }
+
+ // Build the ScopedAllocator node that will be assigned to allocate
+ // the output tensors of the input node set.
+ Status ConstructScopedAllocatorNode(
+ ScopedAllocatorOptimizer* sa_opti, GraphDef* graph, NodeMap* node_map,
+ const std::vector<NodeDef*>& ops, const string& device_name,
+ DataType dtype, int sa_id, const string& sa_name,
+ const std::vector<TensorShape>& input_shapes,
+ const std::vector<InputDesc>& inputs, const TensorShape& sa_shape) {
+ VLOG(2) << "ConstructScopedAllocatorNode " << sa_name;
+ NodeDefBuilder sa_builder(sa_name, "_ScopedAllocator");
+ sa_builder.Device(device_name);
+ sa_builder.Attr("sa_name", sa_name);
+ sa_builder.Attr("T", dtype);
+ sa_builder.Attr("id", sa_id);
+ sa_builder.Attr("shapes", input_shapes);
+ sa_builder.Attr("shape", sa_shape);
+ sa_builder.Attr("expected_call_count", static_cast<int64>(ops.size()));
+ NodeDef* sa_node = graph->add_node();
+ LOG_WARNING_AND_RETURN_IF_ERROR(sa_builder.Finalize(sa_node));
+ node_map->AddNode(sa_name, sa_node);
+
+ // Add control edges from the ScopedAllocatorOp to all of the
+ // input nodes and mark them for allocation from backing tensor.
+ for (int i = 0; i < inputs.size(); ++i) {
+ auto& nd = inputs[i];
+ VLOG(2) << "To input " << i << ": " << nd.from_node_def->name()
+ << " add control input "
+ << "^" << sa_name;
+ nd.from_node_def->add_input(strings::StrCat("^", sa_name));
+ // This attribute says: allocate output_slot from
+ // ScopedAllocator instance sa_id + 1 + i.
+ ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator",
+ {nd.output_slot, sa_id + 1 + i},
+ nd.from_node_def);
+ node_map->AddOutput(sa_name, nd.from_node_def->name());
+ }
+ return Status::OK();
+ }
+
+ Status BuildSAConcatNode(GraphDef* graph, NodeMap* node_map,
+ const std::vector<NodeDef*>& ops,
+ const std::set<string>& op_instance_names,
+ const string& device_name, DataType dtype, int sa_id,
+ const string& sa_name, const string& sac_name,
+ const TensorShape& sa_shape,
+ std::vector<NodeDefBuilder::NodeOut>* sac_inputs) {
+ VLOG(2) << "BuildSAConcatNode " << sac_name;
+ std::set<string> sac_ctl_inputs;
+ for (int i = 0; i < ops.size(); ++i) {
+ NodeDef* old_op = ops[i];
+ for (const string& old_op_input : old_op->input()) {
+ int position = 0;
+ string input_name = ParseNodeName(old_op_input, &position);
+ if (position == -1) {
+ // A control input: drop if from another member of the op set.
+ if (op_instance_names.find(old_op_input) == op_instance_names.end()) {
+ sac_ctl_inputs.insert(old_op_input);
+ }
+ } else {
+ // TODO(tucker): remove redundant check.
+ // A data input: illegal if from another member of the op set.
+ if (op_instance_names.find(old_op_input) != op_instance_names.end()) {
+ LOG(ERROR) << "Data edge between " << old_op_input << " and "
+ << old_op->name() << " cannot build ScopedAllocator.";
+ return errors::Internal("Data edge between ", old_op_input, " and ",
+ old_op->name(),
+ " cannot build ScopedAllocator.");
+ }
+ sac_inputs->push_back(
+ NodeDefBuilder::NodeOut(old_op_input, 0, dtype));
+ }
+ VLOG(3) << "from op " << i << ": " << old_op->name()
+ << " sac_inputs append " << old_op_input;
+ }
+ }
+ NodeDefBuilder sac_builder(sac_name, "_ScopedAllocatorConcat");
+ VLOG(2) << "New sac_name " << sac_name << " shape "
+ << sa_shape.DebugString();
+ sac_builder.Device(device_name);
+ sac_builder.Attr("sa_name", sa_name);
+ sac_builder.Attr("id", sa_id);
+ sac_builder.Attr("T", dtype);
+ sac_builder.Attr("shape", sa_shape);
+ sac_builder.Attr("N", static_cast<int>(sac_inputs->size()));
+ sac_builder.Input(NodeDefBuilder::NodeOut(sa_name, 0, dtype));
+ sac_builder.Input(*sac_inputs);
+ NodeDef* sac_node = graph->add_node();
+ LOG_WARNING_AND_RETURN_IF_ERROR(sac_builder.Finalize(sac_node));
+ node_map->AddNode(sac_name, sac_node);
+ node_map->AddOutput(sa_name, sac_name);
+
+ // Attach the old control inputs to the new sac node.
+ for (const string& ctl_input : sac_ctl_inputs) {
+ sac_node->add_input(ctl_input);
+ }
+ return Status::OK();
+ }
+
+ Status BuildReplacementOp(GraphDef* graph, NodeMap* node_map,
+ const std::vector<NodeDef*>& ops,
+ const string& device_name, DataType dtype,
+ const string& op_name, const string& sac_name,
+ const string& sa_op_name) {
+ VLOG(2) << "BuildReplacementOp " << sa_op_name;
+ NodeDefBuilder op_builder(sa_op_name, op_name);
+ op_builder.Device(device_name);
+
+ // Transfer the Node Attr from the first replaced Node to the new
+ // Node. TODO(tucker): In principle we should verify that
+ // the Attr are consistent and compatible across all op instances.
+ // Unfortunately that will probably require op-specific tests, so
+ // punt on that for the time being.
+ AttrSlice first_slice(*ops[0]);
+ for (auto& it : first_slice) {
+ op_builder.Attr(it.first, it.second);
+ }
+ op_builder.Attr("_forward_input", {0, 0});
+ op_builder.Input(sac_name, 0, dtype);
+ NodeDef* sa_op_node = graph->add_node();
+ LOG_WARNING_AND_RETURN_IF_ERROR(op_builder.Finalize(sa_op_node));
+ node_map->AddNode(sa_op_name, sa_op_node);
+ node_map->AddOutput(sac_name, sa_op_name);
+ return Status::OK();
+ }
+
+ Status BuildSplitNode(GraphDef* graph, NodeMap* node_map,
+ const std::vector<NodeDef*>& ops,
+ const std::vector<TensorShape>& input_shapes,
+ const std::vector<NodeDefBuilder::NodeOut>& sac_inputs,
+ const string& device_name, DataType dtype,
+ const string& op_name, int sa_id,
+ const string& sas_name, const string& sa_name,
+ const string& sa_op_name) {
+ VLOG(2) << "new ScopedAllocatorSplit " << sas_name;
+ NodeDefBuilder sas_builder(sas_name, "_ScopedAllocatorSplit");
+ sas_builder.Device(device_name);
+ sas_builder.Attr("sa_name", sa_name);
+ sas_builder.Attr("id", sa_id);
+ sas_builder.Attr("T", dtype);
+ sas_builder.Attr("shapes", input_shapes);
+ std::vector<NodeDefBuilder::NodeOut> sas_inputs = sac_inputs;
+ sas_builder.Attr("N", static_cast<int>(sas_inputs.size()));
+ sas_builder.Input(NodeDefBuilder::NodeOut(sa_op_name, 0, dtype));
+ sas_builder.Input(sas_inputs);
+ NodeDef* sas_node = graph->add_node();
+ LOG_WARNING_AND_RETURN_IF_ERROR(sas_builder.Finalize(sas_node));
+ node_map->AddNode(sas_name, sas_node);
+ node_map->AddOutput(sa_op_name, sas_name);
+ return Status::OK();
+ }
+
+ // After the new ScopedAllocator and its corresponding Concat and
+ // Split nodes have been built, and a new single Op instance
+ // constructed, rewire the graph: Remove input edges to the old Op
+ // nodes and replace the old Op node outputs with the corresponding
+ // ScopedAllocatorSplit node outputs. After this the old Op nodes
+ // should no longer have any input or output edges and they can be
+ // removed from the graph.
+ Status RewireSubgraph(GraphDef* graph, NodeMap* node_map,
+ const std::vector<NodeDef*>& ops,
+ const std::set<string>& op_instance_names,
+ const string& op_name, const string& sas_name) {
+ VLOG(2) << "RewireSubgraph";
+ for (int op_idx = 0; op_idx < ops.size(); ++op_idx) {
+ NodeDef* old_op = ops[op_idx];
+ // Copy the output node set since we'll be modifying the version
+ // maintained by NodeMap in the loop.
+ std::set<NodeDef*> output_nodes = node_map->GetOutputs(old_op->name());
+ VLOG(3) << "old_op " << old_op->name() << " had " << output_nodes.size()
+ << " outputs. Moving them to the PASplit node.";
+ if (VLOG_IS_ON(2)) {
+ for (NodeDef* n : output_nodes) {
+ VLOG(3) << " output: " << n->name();
+ }
+ }
+ for (NodeDef* n : output_nodes) {
+ VLOG(3) << "really checking old output " << n->name()
+ << " for corresponding input.";
+ if (op_instance_names.find(n->name()) != op_instance_names.end()) {
+ // If this output node is a member of the ops set, it must have
+ // been an internal control edge so drop it.
+ VLOG(3) << "Dropping control output from " << old_op->name() << " to "
+ << n->name();
+ // However, we may already have dropped it at the clear() below,
+ // so if we fail to find it, that's okay.
+ Status ignore = RemoveEdge(strings::StrCat("^", old_op->name()),
+ old_op->name(), n, node_map);
+ continue;
+ }
+ bool found = false;
+ VLOG(3) << "about to iterate over " << n->input_size() << " inputs";
+ for (int i = 0; i < n->input_size(); ++i) {
+ VLOG(3) << "input " << n->input(i);
+ int position = 0;
+ string input_node = ParseNodeName(n->input(i), &position);
+ if (input_node == old_op->name()) {
+ found = true;
+ VLOG(3) << "match pos=" << position;
+ if (position == -1) {
+ // It was a control edge
+ *n->mutable_input(i) = strings::StrCat("^", sas_name);
+ } else {
+ CHECK_EQ(0, position)
+ << "name " << n->input(i) << " pos " << position;
+ *n->mutable_input(i) = strings::StrCat(sas_name, ":", op_idx);
+ }
+ node_map->RemoveOutput(old_op->name(), n->name());
+ node_map->AddOutput(sas_name, n->name());
+ VLOG(3) << "breaking on success";
+ break;
+ } else {
+ VLOG(3) << "other input " << n->input(i);
+ }
+ }
+ // In general it's required that we found the output node's old
+ // input and replaced it, but one exception is if the output node
+ // is of the same type being coalesced and the edge is a control
+ // input. In that case it probably got eliminated in an earlier
+ // pass.
+ VLOG(3) << "before HasOp";
+ if (!HasOpName(n->name(), op_name)) {
+ CHECK(found) << "old_op " << old_op->name() << " node "
+ << " could not find input edge on " << n->DebugString()
+ << " to replace."
+ << " " << op_name << " not in " << n->name();
+ }
+ VLOG(3) << "bottom of for output_nodes";
+ }
+ VLOG(3) << "Clearing all inputs of " << old_op->name();
+ node_map->RemoveInputs(old_op->name());
+ old_op->clear_input();
+ node_map->RemoveOutputs(old_op->name());
+ VLOG(3) << "after clear: " << old_op->DebugString();
+ // old_op should be dead, with no further inputs or outputs.
+ // It needs to be removed altogether before the graph is generated,
+ // but we need to leave it around until this Optimizer is done,
+ // because there may be some
+ // Remove.
+ RemoveNode(old_op, graph, node_map);
+ }
+ return Status::OK();
+ }
+
+ // Given a collection of instances of op_name, presumed to be
+ // logically parallel and operating on tensors of the same type,
+ // replace them by a single instance. First find the upstream Ops
+ // generating their inputs. Create a new ScopedAllocatorOp that
+ // outputs a single backing_tensor pre-arranged for sub-allocation
+ // of all of those input tensors. Then insert a new
+ // ScopedAllocatorConcatOp below the upstream Ops to make explicit
+ // the materialization of a concatenation of their outputs. Put the
+ // new op_name instance below the new concat op and follow with a
+ // ScopedAllocatorSplitOp that restores the correct shape outputs
+ // for the consumers of the old op_name instances.
+ //
+ // There must be no non-control edges between Nodes in 'ops'.
+ // Control edges among these nodes will be dropped.
+ Status Rewrite(ScopedAllocatorOptimizer* sa_opti, GraphDef* graph,
+ const string& op_name, const std::vector<NodeDef*>& ops,
+ bool* applied) override {
+ if (VLOG_IS_ON(1)) {
+ VLOG(1) << "Rewrite";
+ string op_names;
+ for (auto& nd : ops) {
+ strings::StrAppend(&op_names, nd->name(), ", ");
+ }
+ VLOG(1) << "UnaryElementwiseRewriter::Rewrite " << op_name
+ << " to: " << op_names;
+ }
+ NodeMap* node_map = sa_opti->node_map();
+
+ // Make a set of the node names for faster membership testing.
+ std::set<string> op_instance_names;
+ for (auto& nd : ops) {
+ op_instance_names.insert(nd->name());
+ VLOG(2) << "op_instance_name " << nd->name();
+ }
+ DataType dtype;
+ std::vector<TensorShape> input_shapes;
+ std::vector<InputDesc> inputs;
+ TensorShape sa_shape;
+ string device_name;
+
+ TF_RETURN_IF_ERROR(AnalyzeInputs(sa_opti, node_map, ops, op_instance_names,
+ &device_name, &dtype, &input_shapes,
+ &inputs, &sa_shape));
+
+ int sa_id = sa_opti->NewScopedAllocatorId(input_shapes.size());
+ string sa_name = strings::StrCat("scoped_allocator_", sa_id);
+ TF_RETURN_IF_ERROR(ConstructScopedAllocatorNode(
+ sa_opti, graph, node_map, ops, device_name, dtype, sa_id, sa_name,
+ input_shapes, inputs, sa_shape));
+
+ // TODO(tucker): Maybe add control edges to delay execution of the
+ // ScopedAllocatorOp until just before first use in order to
+ // conserve memory. What would be correct? Let I0...In be the
+ // input nodes that are all going to alloc from SA. If we make
+ // SA wait until all of these are ready, that might be too slow.
+ // It should probably wait until at least one is ready, but which
+ // one? Maybe just pick the first.
+ // {
+ // auto& nd = inputs[0];
+ // std::vector<InputDesc> inputs_to_first;
+ // LOG_WARNING_AND_RETURN_IF_ERROR(GetInputs(sa_opti->node_map(),
+ // {nd.from_node_def},
+ // dtype, &inputs_to_first));
+ // for (int i = 0; i < inputs_to_first.size(); ++i) {
+ // sa_node->add_input(
+ // strings::StrCat("^", inputs_to_first[i].from_node_def->name()));
+ // }
+ // }
+
+ // Build a ScopedAllocatorConcat below all of the input nodes.
+ std::vector<NodeDefBuilder::NodeOut> sac_inputs;
+ string sac_name = strings::StrCat("scoped_allocator_concat_", sa_id);
+ TF_RETURN_IF_ERROR(BuildSAConcatNode(
+ graph, node_map, ops, op_instance_names, device_name, dtype, sa_id,
+ sa_name, sac_name, sa_shape, &sac_inputs));
+
+ // Construct a new instance of the parallel op and insert it
+ // immediately below the new ScopedAllocatorConcat.
+ string sa_op_name = strings::StrCat(sa_name, "_", op_name);
+ TF_RETURN_IF_ERROR(BuildReplacementOp(graph, node_map, ops, device_name,
+ dtype, op_name, sac_name,
+ sa_op_name));
+
+ // Build a ScopedAllocatorSplit split below the new Op.
+ string sas_name = strings::StrCat("scoped_allocator_split_", sa_id);
+ TF_RETURN_IF_ERROR(BuildSplitNode(graph, node_map, ops, input_shapes,
+ sac_inputs, device_name, dtype, op_name,
+ sa_id, sas_name, sa_name, sa_op_name));
+
+ // Rewire the graph.
+ TF_RETURN_IF_ERROR(RewireSubgraph(graph, node_map, ops, op_instance_names,
+ op_name, sas_name));
+
+ *applied = true;
+ return Status::OK();
+ }
+};
+
+ScopedAllocatorOptimizer::ScopedAllocatorOptimizer(
+ const ScopedAllocatorOptions& opts) {
+ VLOG(1) << "ScopedAllocatorOptimizer::ScopedAllocatorOptimizer";
+ Rewriter* r = new UnaryElementwiseRewriter();
+ to_delete_.push_back(r);
+ if (opts.enable_op_size() == 0) {
+ // Opts handled by default:
+ for (const auto& op_name : {"CollectiveReduce"}) {
+ op_name_set_.insert(op_name);
+ rewriters_[op_name] = r;
+ }
+ } else {
+ for (const auto& op_name : opts.enable_op()) {
+ op_name_set_.insert(op_name);
+ rewriters_[op_name] = r;
+ }
+ }
+}
+
+Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
+ const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+ // Nodes that cannot be removed from the graph without damaging correctness,
+ // typically fetch nodes.
+ nodes_to_preserve_ = item.NodesToPreserve();
+
+ GraphProperties graph_properties(item);
+ const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
+ LOG_WARNING_AND_RETURN_IF_ERROR(
+ graph_properties.InferStatically(assume_valid_feeds));
+ node_map_.reset(new NodeMap(optimized_graph));
+
+ LOG_WARNING_AND_RETURN_IF_ERROR(ScopedAllocatorOptimizer::ProcessGraphDef(
+ optimized_graph, graph_properties));
+
+ VLOG(1) << "ScopedAllocatorOptimizer::Optimize() done";
+ return Status::OK();
+}
+
+ScopedAllocatorOptimizer::Rewriter* ScopedAllocatorOptimizer::GetRewriter(
+ const string& op_name) {
+ auto it = rewriters_.find(op_name);
+ if (it != rewriters_.end()) {
+ return it->second;
+ }
+ return nullptr;
+}
+
+int ScopedAllocatorOptimizer::NewScopedAllocatorId(int num_fields) {
+ CHECK_GT(num_fields, 0);
+ int id = next_sa_id_;
+ next_sa_id_ += (num_fields + 1);
+ CHECK_GT(next_sa_id_, 0);
+ return id;
+}
+
+ScopedAllocatorOptimizer::~ScopedAllocatorOptimizer() {
+ for (auto ptr : to_delete_) {
+ delete ptr;
+ }
+}
+
+void ScopedAllocatorOptimizer::FindOpOccurrences(GraphDef* graph,
+ const OpNameSet& op_names,
+ GraphOpOccurrences* occs) {
+ VLOG(1) << "FindOpOccurrences ";
+ for (const auto& it : op_names) {
+ VLOG(1) << "search target " << it;
+ }
+ for (int ni = 0; ni < graph->node_size(); ++ni) {
+ NodeDef* node = graph->mutable_node(ni);
+ const string& op_name = node->op();
+ if (op_names.find(op_name) != op_names.end()) {
+ VLOG(1) << "found " << op_name << " on dev " << node->device();
+ (*occs)[node->device()][op_name].push_back(node);
+ }
+ }
+}
+
+namespace {
+struct OpNameOrder {
+ bool operator()(const NodeDef* a, const NodeDef* b) {
+ return a->name() <= b->name();
+ }
+};
+
+class Tree {
+ public:
+ Tree(const string& edge, int depth) : edge_(edge), depth_(depth) {}
+ ~Tree() {
+ for (auto it : subtrees_) delete it.second;
+ }
+
+ Tree* GetSubTree(const string& edge) {
+ auto it = subtrees_.find(edge);
+ if (it != subtrees_.end()) {
+ return it->second;
+ }
+ Tree* t = new Tree(edge, depth_ + 1);
+ subtrees_[edge] = t;
+ return t;
+ }
+
+ void InsertNode(NodeDef* n) { nodes_.push_back(n); }
+
+ string edge_;
+ int depth_;
+ std::vector<NodeDef*> nodes_;
+ std::unordered_map<string, Tree*> subtrees_;
+};
+
+// Applies a function to every Tree in DFS order. Terminates early
+// on any non-OK Status.
+Status ApplyToAll(Tree* tree, const std::function<Status(Tree*)>& func) {
+ Status s;
+ for (auto it : tree->subtrees_) {
+ s = ApplyToAll(it.second, func);
+ if (!s.ok()) return s;
+ }
+ s = func(tree);
+ return s;
+}
+
+Tree* ComputeScopeTree(const string& op_name,
+ const std::vector<NodeDef*>& node_vec) {
+ Tree* root = new Tree("", 0);
+ for (NodeDef* n : node_vec) {
+ std::vector<string> pieces = str_util::Split(n->name(), "/");
+ // last piece is node name proper.
+ int depth = pieces.size() - 1;
+ Tree* subtree = root;
+ for (int i = 0; i < depth; ++i) {
+ subtree = subtree->GetSubTree(pieces[i]);
+ }
+ subtree->InsertNode(n);
+ }
+ return root;
+}
+
+void PartitionByLoopStructure(const FrameMap& frame_map,
+ std::vector<NodeDef*> nodes,
+ std::vector<std::vector<NodeDef*>>* loop_groups) {
+ // It is assumed that two nodes with identical loop containment have
+ // identical integer vectors. Represent those by 64 bit hashes.
+ std::unordered_map<uint64, std::vector<NodeDef*>> loop_sets;
+ for (NodeDef* nd : nodes) {
+ uint64 hash = 0;
+ const auto& it = frame_map.find(nd);
+ if (it != frame_map.end()) {
+ const std::vector<int>& loop_ids = it->second;
+ for (int id : loop_ids) {
+ hash = Hash64Combine(hash, static_cast<uint64>(id));
+ }
+ }
+ loop_sets[hash].push_back(nd);
+ }
+ for (auto it : loop_sets) {
+ loop_groups->push_back(std::move(it.second));
+ }
+}
+
+} // namespace
+
+Status ScopedAllocatorOptimizer::ProcessGraphDef(
+ GraphDef* graph, const GraphProperties& graph_properties) {
+ VLOG(1) << "ProcessGraphDef";
+ Status status;
+ GraphOpOccurrences occ;
+ FindOpOccurrences(graph, op_name_set_, &occ);
+ if (!occ.empty()) {
+ FrameMap frame_map;
+ int num_frames;
+ LOG_WARNING_AND_RETURN_IF_ERROR(
+ IdentifyFramesWithNodeMap(*graph, *node_map_, &frame_map, &num_frames));
+ for (auto& dt : occ) {
+ VLOG(2) << "Processing device " << dt.first;
+ const DevOpOccurrences& dev_occ = dt.second;
+ for (auto& it : dev_occ) {
+ string op_name = it.first;
+ VLOG(1) << "Processing " << op_name << " set size " << it.second.size();
+ Rewriter* rewriter = GetRewriter(op_name);
+ if (!rewriter) {
+ LOG(ERROR) << "Failed to find PARewriter for op_name " << op_name;
+ continue;
+ }
+ rewriter->SetGraphProperties(graph_properties);
+ std::unique_ptr<Tree> root(ComputeScopeTree(it.first, it.second));
+ // Nodes with a common depth and root path are now grouped
+ // in the same Tree struct. Split those groups into subgroups that
+ // share identical loop nesting.
+ status = ApplyToAll(
+ root.get(), [this, rewriter, graph, &frame_map, &op_name](Tree* t) {
+ VLOG(2) << "applied to tree node " << t->edge_ << " at depth "
+ << t->depth_ << " of size " << t->nodes_.size();
+ if (t->nodes_.size() > 1) {
+ std::vector<std::vector<NodeDef*>> loop_groups;
+ PartitionByLoopStructure(frame_map, t->nodes_, &loop_groups);
+ for (auto& lg : loop_groups) {
+ if (lg.size() > 1) {
+ bool applied = false;
+ Status s = OrderNodeSet(&lg);
+ TF_RETURN_IF_ERROR(s);
+ VLOG(1) << "Applying Rewriter for " << op_name;
+ s = rewriter->Rewrite(this, graph, op_name, lg, &applied);
+ LOG_WARNING_AND_RETURN_IF_ERROR(s);
+ }
+ }
+ }
+ return Status::OK();
+ });
+ if (!status.ok()) {
+ break;
+ }
+ }
+ if (!status.ok()) {
+ break;
+ }
+ }
+ }
+ VLOG(1) << "ScopedAllocatorOptimizer returning " << status;
+ if (!status.ok()) {
+ LOG(ERROR) << "ScopedAllocatorOptimizer: " << status;
+ }
+ return status;
+}
+
+namespace {
+struct InstanceKeyLess {
+ bool operator()(const NodeDef* a, const NodeDef* b) const {
+ AttrSlice a_attrs = AttrSlice(*a);
+ AttrSlice b_attrs = AttrSlice(*b);
+ int32 a_key = -1;
+ int32 b_key = -1;
+ Status s = GetNodeAttr(a_attrs, "instance_key", &a_key);
+ CHECK(s.ok());
+ s = GetNodeAttr(b_attrs, "instance_key", &b_key);
+ CHECK(s.ok());
+ return a_key < b_key;
+ }
+};
+
+struct NameLess {
+ bool operator()(const NodeDef* a, const NodeDef* b) const {
+ return a->name() < b->name();
+ }
+};
+
+bool IsCollectiveNode(const NodeDef& n) {
+ AttrSlice attrs = AttrSlice(n);
+ int key = -1;
+ if (!IsCollective(n)) return false;
+ Status s = GetNodeAttr(attrs, "instance_key", &key);
+ if (s.ok() && key >= 0) {
+ return true;
+ }
+ return false;
+}
+} // namespace
+
+Status ScopedAllocatorOptimizer::OrderNodeSet(
+ std::vector<NodeDef*>* nodes) const {
+ // Nodes should be identical type. Default order is by name but for
+ // collectives we order by increasing instance_key so each group gets
+ // the same instance_key.
+ if (nodes->size() <= 1) return Status::OK();
+ if (IsCollectiveNode(*nodes->at(0))) {
+ sort(nodes->begin(), nodes->end(), InstanceKeyLess());
+ } else {
+ sort(nodes->begin(), nodes->end(), NameLess());
+ }
+ return Status::OK();
+}
+
+} // namespace grappler
+} // namespace tensorflow
+
+#undef LOG_WARNING_AND_RETURN_IF_ERROR
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h
new file mode 100644
index 0000000000..ab4d444595
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h
@@ -0,0 +1,107 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SCOPED_ALLOCATOR_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SCOPED_ALLOCATOR_OPTIMIZER_H_
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+class Graph;
+class GraphProperties;
+class NodeMap;
+class ScopedAllocatorOptimizer;
+
+// An Optimizer that introduces ScopedAllocators in order to reduce data
+// movement and consolidate some kinds of Ops.
+class ScopedAllocatorOptimizer : public GraphOptimizer {
+ public:
+ explicit ScopedAllocatorOptimizer(const ScopedAllocatorOptions& opts);
+ ~ScopedAllocatorOptimizer() override;
+
+ string name() const override { return "scoped_allocator_optimizer"; }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ // Map from an Op name to a vector of Nodes with that Op.
+ typedef std::unordered_map<string, std::vector<NodeDef*>> DevOpOccurrences;
+ // Map from a device name to a DevOpOccurrences map.
+ typedef std::unordered_map<string, DevOpOccurrences> GraphOpOccurrences;
+ typedef std::unordered_set<string> OpNameSet;
+
+ Status ProcessGraphDef(GraphDef* graph,
+ const GraphProperties& graph_properties);
+
+ // Populates *occs by grouping Nodes with common Ops, according to
+ // their assigned devices.
+ void FindOpOccurrences(GraphDef* graph, const OpNameSet& op_names,
+ GraphOpOccurrences* occs);
+
+ // Returns a new, unused scope_id to be assigned to a ScopedAllocator that
+ // will allocate num_fields (> 0) separate tensors.
+ int NewScopedAllocatorId(int num_fields);
+
+ NodeMap* node_map() { return node_map_.get(); }
+
+ // Appends values to the attr value under name in node_def, if present.
+ // If not present does an assignment.
+ static void ExtendNodeAttr(StringPiece name, const std::vector<int32>& values,
+ NodeDef* node_def);
+
+ // Class that knows how to do graph rewriting for a particular kind of Op in
+ // order to take advantage of a ScopedAllocator.
+ class Rewriter {
+ public:
+ virtual ~Rewriter() {}
+
+ virtual Status Rewrite(ScopedAllocatorOptimizer* paopti, GraphDef* graph,
+ const string& op_name,
+ const std::vector<NodeDef*>& nodes,
+ bool* applied) = 0;
+
+ void SetGraphProperties(const GraphProperties& graph_properties) {
+ graph_properties_ = &graph_properties;
+ CHECK(graph_properties_);
+ }
+
+ protected:
+ const GraphProperties* graph_properties_;
+ };
+
+ private:
+ Rewriter* GetRewriter(const string& op_name);
+
+ Status OrderNodeSet(std::vector<NodeDef*>* nodes) const;
+
+ RewriterConfig::Toggle opt_level_;
+ std::unordered_set<string> nodes_to_preserve_;
+ OpNameSet op_name_set_;
+ std::unordered_map<string, Rewriter*> rewriters_;
+ std::vector<Rewriter*> to_delete_;
+ int next_sa_id_ = 1;
+ std::unique_ptr<NodeMap> node_map_;
+};
+
+} // namespace grappler
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SCOPED_ALLOCATOR_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
new file mode 100644
index 0000000000..3a2859dc5f
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
@@ -0,0 +1,243 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
+
+#include <unordered_set>
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class ScopedAllocatorOptimizerTest : public ::testing::Test {
+ public:
+ std::unique_ptr<Session> CreateSession(const GraphDef& graph,
+ const ConfigProto& config) {
+ SessionOptions options;
+ options.config = config;
+ (*options.config.mutable_device_count())["CPU"] = 2;
+ Session* session = NewSession(options);
+ TF_CHECK_OK(session->Create(graph));
+ return std::unique_ptr<Session>(session);
+ }
+
+ std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
+ const std::vector<string>& fetch) {
+ SessionOptions options;
+ std::unique_ptr<tensorflow::Session> session(NewSession(options));
+ TF_CHECK_OK(session->Create(graph));
+ RunOptions run_options;
+ std::vector<Tensor> output_tensors;
+ TF_CHECK_OK(
+ session->Run(run_options, {}, fetch, fetch, &output_tensors, nullptr));
+ TF_CHECK_OK(session->Close());
+ return output_tensors;
+ }
+
+ // Constructs the following graph.
+ // (Flow is top to bottom, like nature intends.)
+ //
+ // The intended optimization is to have s1 and s2 allocate from
+ // an new ScopedAllocator, then replace a1 and a2 with a3 that
+ // reads from the backing buffer.
+ /*
+ a b c
+ \ / \ /
+ s1 s2
+ | |
+ a1 a2
+ | |
+ r1 r2
+ */
+ void BuildAbsGraph(GraphDef* graph_def) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
+
+ Output a =
+ ops::Const<float>(s.WithOpName("a"), {1.0, 0.0, 0.0, -1.0}, {2, 2});
+ Output b =
+ ops::Const<float>(s.WithOpName("b"), {1.0, -2.0, 3.0, 4.0}, {2, 2});
+ Output c =
+ ops::Const<float>(s.WithOpName("c"), {-5.0, -2.0, 0.0, -2.0}, {2, 2});
+ Output s1 = ops::Add(s.WithOpName("s1"), a, b);
+ Output s2 = ops::Add(s.WithOpName("s2"), b, c);
+ Output a1 = ops::Abs(s.WithOpName("a1"), s1);
+ Output a2 = ops::Abs(s.WithOpName("a2"), s2);
+ Output r1 = ops::Reshape(s.WithOpName("r1"), a1, {1, 4});
+ Output r2 = ops::Reshape(s.WithOpName("r2"), a2, {4, 1});
+ TF_CHECK_OK(s.ToGraphDef(graph_def));
+ }
+
+ void SetShapes(GraphDef* graph_def) {
+ TensorShapeProto shape_proto;
+ shape_proto.add_dim()->set_size(2);
+ shape_proto.add_dim()->set_size(2);
+
+ for (NodeDef& n : *graph_def->mutable_node()) {
+ if (n.op() == "Add" || n.op() == "Abs") {
+ AddNodeAttr("_output_shapes", {shape_proto}, &n);
+ }
+ }
+ }
+};
+
+TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
+ // Tests that Rewrite of program with parallel unary Ops is done as
+ // anticipated.
+ GrapplerItem item;
+ BuildAbsGraph(&item.graph);
+ SetShapes(&item.graph);
+
+ ScopedAllocatorOptions opts;
+ opts.add_enable_op("Abs");
+ ScopedAllocatorOptimizer sao(opts);
+ ScopedAllocatorOptimizer::OpNameSet ons;
+ ons.insert("Abs");
+
+ GraphDef optimized_graph;
+ TF_ASSERT_OK(sao.Optimize(nullptr /*cluster*/, item, &optimized_graph));
+
+ // Examine the resulting graph def.
+ NodeMap node_map(&optimized_graph);
+ NodeDef* nd = node_map.GetNode("scoped_allocator_1");
+ ASSERT_TRUE(nd);
+ {
+ auto& nd_set = node_map.GetOutputs(nd->name());
+ ASSERT_EQ(3, nd_set.size());
+ std::unordered_set<string> expected = {"scoped_allocator_concat_1", "s1",
+ "s2"};
+ for (auto it : nd_set) {
+ ASSERT_NE(expected.find(it->name()), expected.end())
+ << "Failed to find " << it->name();
+ }
+ }
+ {
+ auto& nd_set = node_map.GetOutputs("scoped_allocator_concat_1");
+ ASSERT_EQ(1, nd_set.size());
+ for (auto it : nd_set) {
+ ASSERT_EQ("scoped_allocator_1_Abs", it->name());
+ }
+ }
+ {
+ auto& nd_set = node_map.GetOutputs("scoped_allocator_1_Abs");
+ ASSERT_EQ(1, nd_set.size());
+ for (auto it : nd_set) {
+ ASSERT_EQ("scoped_allocator_split_1", it->name());
+ }
+ }
+ {
+ auto& nd_set = node_map.GetOutputs("scoped_allocator_split_1");
+ ASSERT_EQ(2, nd_set.size());
+ std::unordered_set<string> name_set;
+ for (auto it : nd_set) {
+ name_set.insert(it->name());
+ }
+ ASSERT_TRUE(name_set.find("r1") != name_set.end());
+ ASSERT_TRUE(name_set.find("r2") != name_set.end());
+ }
+}
+
+TEST_F(ScopedAllocatorOptimizerTest, UnaryExecute) {
+ // Constructs the same graph as UnaryRewriteOnly, but actually executes it.
+ GrapplerItem item;
+ BuildAbsGraph(&item.graph);
+
+ // Turn off all optimization except the ScopedAllocatorOptimizer
+ // to avoid anything that would alter the expected graph input/output,
+ // e.g. by constant folding away all calculations.
+ ConfigProto config;
+ GraphOptions* gopt = config.mutable_graph_options();
+ OptimizerOptions* opts = gopt->mutable_optimizer_options();
+ opts->set_do_common_subexpression_elimination(false);
+ opts->set_do_constant_folding(false);
+ opts->set_do_function_inlining(false);
+ opts->set_opt_level(OptimizerOptions::L0);
+ RewriterConfig* rwcfg = gopt->mutable_rewrite_options();
+ rwcfg->clear_optimizers();
+ (*rwcfg->add_optimizers()) = "scoped_allocator";
+ rwcfg->mutable_scoped_allocator_opts()->add_enable_op("Abs");
+ std::unique_ptr<Session> session(CreateSession(item.graph, config));
+
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetch output and one non-fetched output.
+ std::vector<string> output_names = {"r1:0", "r2:0",
+ "scoped_allocator_1_Abs:0"};
+ std::vector<string> target_nodes = {};
+ std::vector<Tensor> outputs;
+ Status s = session->Run(inputs, output_names, target_nodes, &outputs);
+ TF_ASSERT_OK(s);
+ ASSERT_EQ(outputs.size(), 3);
+ std::vector<float> expected_r1({2, 2, 3, 3});
+ std::vector<float> expected_r2({4, 4, 3, 2});
+ // a + b == 2, -2, 3, 3
+ // b + c == -4, -4, 3, 2
+ for (int oi = 0; oi < outputs.size(); ++oi) {
+ for (int i = 0; i < outputs[oi].NumElements(); ++i) {
+ VLOG(0) << "output vec " << oi << " index " << i << " = "
+ << outputs[oi].flat<float>()(i);
+ }
+ if (oi == 0) {
+ ASSERT_EQ(expected_r1.size(), outputs[oi].NumElements());
+ for (int i = 0; i < expected_r1.size(); ++i) {
+ EXPECT_EQ(expected_r1[i], outputs[oi].flat<float>()(i));
+ }
+ } else if (oi == 1) {
+ ASSERT_EQ(expected_r2.size(), outputs[oi].NumElements());
+ for (int i = 0; i < expected_r2.size(); ++i) {
+ EXPECT_EQ(expected_r2[i], outputs[oi].flat<float>()(i));
+ }
+ }
+ }
+}
+
+// Tests static ScopedAllocatorOptimizer::ExtendNodeAttr.
+// Maybe this should be moved elsewhere?
+TEST_F(ScopedAllocatorOptimizerTest, Extend) {
+ NodeDef nd;
+ ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {0, 2}, &nd);
+ ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {6, 7}, &nd);
+ ScopedAllocatorOptimizer::ExtendNodeAttr("_scoped_allocator", {2, 3}, &nd);
+ VLOG(0) << "nd: " << nd.DebugString();
+ std::vector<int> scoped_allocator_attrs;
+ AttrSlice slice(nd);
+ Status sa_status =
+ GetNodeAttr(slice, "_scoped_allocator", &scoped_allocator_attrs);
+ for (int i : scoped_allocator_attrs) {
+ VLOG(0) << "extracted: " << i;
+ }
+ NodeDef nd2;
+ AddNodeAttr("_scoped_allocator", {0, 2}, &nd2);
+ AddNodeAttr("_scoped_allocator", {6, 7}, &nd2);
+ AddNodeAttr("_scoped_allocator", {2, 3}, &nd2);
+ VLOG(0) << "nd2: " << nd2.DebugString();
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 5a5dc47fa0..d64cb49715 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
@@ -525,7 +526,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
NodeDef* placeholder = function_body.add_node();
placeholder->set_name(input.name());
placeholder->set_op("Placeholder");
- (*placeholder->mutable_attr())["T"].set_type(input_data_type);
+ (*placeholder->mutable_attr())["dtype"].set_type(input_data_type);
+ (*placeholder->mutable_attr())["shape"].mutable_shape()->set_unknown_rank(
+ true);
InputArgExpansion input_expansion{/*input_name=*/input.name(),
/*data_type=*/input_data_type,
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index 302f02dd39..8c3cc70351 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -256,7 +256,7 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
for (const NodeDef &node : item.function_body().node()) {
if (node.name() == "x" && count++) {
EXPECT_EQ("Placeholder", node.op());
- EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
+ EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "two" && count++) {
EXPECT_EQ("Const", node.op());
@@ -333,7 +333,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
if (node.name() == "x" || node.name() == "y" || node.name() == "dz") {
count++;
EXPECT_EQ("Placeholder", node.op());
- EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
+ EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "rx" && count++) {
EXPECT_EQ("BroadcastGradientArgs", node.op());
@@ -402,7 +402,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
if (node.name() == "x" || node.name() == "y") {
count++;
EXPECT_EQ("Placeholder", node.op());
- EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
+ EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "a0" && count++) {
EXPECT_EQ("Swap", node.op());
@@ -465,7 +465,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
for (const NodeDef &node : item.function_body().node()) {
if (node.name() == "in" && count++) {
EXPECT_EQ("Placeholder", node.op());
- EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
+ EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
EXPECT_EQ(0, node.input_size());
} else if (node.name() == "Linear_func" && count++) {
EXPECT_EQ("Identity", node.op());
@@ -517,9 +517,9 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
count++;
EXPECT_EQ("Placeholder", node.op());
if (node.name() == "arg3") {
- EXPECT_EQ(DT_INT32, node.attr().at("T").type());
+ EXPECT_EQ(DT_INT32, node.attr().at("dtype").type());
} else {
- EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
+ EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
}
}
EXPECT_EQ(5, count);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1bf6eafb58..5948f8d39f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1968,6 +1968,7 @@ tf_kernel_library(
tf_kernel_library(
name = "resource_variable_ops",
srcs = ["resource_variable_ops.cc"],
+ hdrs = ["resource_variable_ops.h"],
deps = [
":bounds_check",
":dense_update_functor",
@@ -5178,9 +5179,9 @@ filegroup(
"partitioned_function_ops.cc",
# Excluded due to experimental status:
"debug_ops.*",
- "scatter_nd_op*",
"mutex_ops.*",
"batch_kernels.*",
+ "regex_full_match_op.cc",
"regex_replace_op.cc",
],
),
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 6dfcd63ab3..53bdd482cb 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -255,7 +255,7 @@ class BoostedTreesMakeStatsSummaryOp : public OpKernel {
// node_ids
const Tensor* node_ids_t;
OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
- const auto node_ids = node_ids_t->vec<int32>();
+ const auto node_ids = node_ids_t->flat<int32>();
// gradients
const Tensor* gradients_t;
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
@@ -268,12 +268,6 @@ class BoostedTreesMakeStatsSummaryOp : public OpKernel {
OpInputList bucketized_features_list;
OP_REQUIRES_OK(context, context->input_list("bucketized_features_list",
&bucketized_features_list));
- std::vector<tensorflow::TTypes<int32>::ConstVec> bucketized_features;
- bucketized_features.reserve(num_features_);
- for (const Tensor& tensor : bucketized_features_list) {
- bucketized_features.emplace_back(tensor.vec<int32>());
- }
-
// Infer batch size.
const int64 batch_size = node_ids_t->dim_size(0);
// Allocate output stats tensor (Rank 4).
@@ -282,18 +276,39 @@ class BoostedTreesMakeStatsSummaryOp : public OpKernel {
"stats_summary",
{num_features_, max_splits_, num_buckets_, 2},
&output_stats_summary_t));
- auto output_stats_summary = output_stats_summary_t->tensor<float, 4>();
- output_stats_summary.setZero();
+ auto output_stats_summary = output_stats_summary_t->flat<float>();
+ EIGEN_STATIC_ASSERT(
+ (static_cast<int>(decltype(output_stats_summary)::Layout) ==
+ static_cast<int>(Eigen::RowMajor)),
+ THIS_METHOD_IS_ONLY_FOR_ROW_MAJOR_MATRICES);
+
+ const int shift_per_node = num_buckets_ * 2;
+ const int shift_per_feature = shift_per_node * max_splits_;
+ const int32 max_index = num_features_ * shift_per_feature;
+ // We use double to sum the gradients and hessians, due to possible
+ // precision loss when summing small float values.
+ std::vector<double> res(max_index, 0);
// Partition by node, and then bucketize.
- for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
- const auto& features = bucketized_features[feature_idx];
+ int feature_idx = 0;
+ int feature_shift = 0;
+ for (const Tensor& tensor : bucketized_features_list) {
+ const auto& features = tensor.flat<int32>();
for (int i = 0; i < batch_size; ++i) {
const int32 node = node_ids(i);
const int32 bucket = features(i);
- output_stats_summary(feature_idx, node, bucket, 0) += gradients(i, 0);
- output_stats_summary(feature_idx, node, bucket, 1) += hessians(i, 0);
+ // Calculate the index in the flattened vector for
+ // [feature_idx][node][bucket][0].
+ const int index = feature_shift + node * shift_per_node + bucket * 2;
+ res[index] += gradients(i, 0);
+ res[index + 1] += hessians(i, 0);
}
+ ++feature_idx;
+ feature_shift += shift_per_feature;
+ }
+ // Copy over the results.
+ for (int i = 0; i < max_index; ++i) {
+ output_stats_summary(i) = res[i];
}
}
diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
index 02cd298745..935619711c 100644
--- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
@@ -21,5 +21,28 @@ REGISTER6(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
double, uint8);
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("NotEqual")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::not_equal_to<int32>>);
#endif
+
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER2(BinaryOp, SYCL, "NotEqual", functor::not_equal_to, float, double);
+
+REGISTER_KERNEL_BUILDER(Name("NotEqual")
+ .Device(DEVICE_SYCL)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::not_equal_to<int32>>);
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index b6bf0ecd09..87bc8ebefe 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -438,6 +438,9 @@ class IteratorStateVariant {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
kIteratorVariantTypeName);
+// Note that IteratorHandleOp holds a reference to the resource it creates. If
+// cleaning up resources with DestroyResourceOp is important, consider creating
+// resource containers with AnonymousIteratorHandleOp instead.
class IteratorHandleOp : public OpKernel {
public:
explicit IteratorHandleOp(OpKernelConstruction* ctx)
@@ -574,6 +577,75 @@ class IteratorHandleOp : public OpKernel {
string name_;
};
+// Like IteratorHandleOp, but creates handles which are never shared, and does
+// not hold a reference to these handles. The latter is important for eager
+// execution, since OpKernel instances generally live as long as the program
+// running them.
+class AnonymousIteratorHandleOp : public OpKernel {
+ public:
+ explicit AnonymousIteratorHandleOp(OpKernelConstruction* context)
+ : OpKernel(context), graph_def_version_(context->graph_def_version()) {
+ OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ FunctionLibraryRuntime* lib;
+ std::unique_ptr<DeviceMgr> device_mgr(nullptr);
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ OP_REQUIRES_OK(context,
+ context->function_library()->Clone(&flib_def, &pflr, &lib));
+
+ ResourceMgr* mgr = context->resource_manager();
+
+ const string container_name = "AnonymousIterator";
+ string unique_name;
+ {
+ mutex_lock l(static_resource_lookup_mutex_);
+ while (true) { // Find an unused name
+ IteratorResource* existing_resource = nullptr;
+ unique_name = strings::StrCat("AnonymousIterator", current_id_++);
+ Status status = mgr->Lookup<IteratorResource>(
+ container_name, unique_name, &existing_resource);
+ if (status.code() == error::NOT_FOUND) {
+ break;
+ }
+ OP_REQUIRES_OK(context, status);
+ existing_resource->Unref();
+ }
+ IteratorResource* new_resource = new IteratorResource(
+ output_dtypes_, output_shapes_, graph_def_version_,
+ std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
+ // Create the resource with our chosen name under the resource lookup
+ // mutex to avoid another kernel racily creating a resource with this
+ // name.
+ OP_REQUIRES_OK(context, mgr->Create<IteratorResource>(
+ container_name, unique_name, new_resource));
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, container_name, unique_name,
+ MakeTypeIndex<IteratorResource>()));
+ }
+
+ private:
+ // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp
+ // instances.
+ static mutex static_resource_lookup_mutex_;
+ // current_id_ is just a hint for creating unique names. If it turns out
+ // there's a collision (e.g. because another AnonymousIteratorHandleOp
+ // instance is generating handles) we'll just skip that id.
+ static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_);
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const int graph_def_version_;
+};
+
+// Static initializers for AnonymousIteratorHandleOp id counting.
+mutex AnonymousIteratorHandleOp::static_resource_lookup_mutex_{
+ LINKER_INITIALIZED};
+int64 AnonymousIteratorHandleOp::current_id_(0);
+
class MakeIteratorOp : public OpKernel {
public:
explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -1066,6 +1138,8 @@ class DeserializeIteratorOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
MakeIteratorOp);
+REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_CPU),
+ AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 879bb40331..f41a810b07 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -211,6 +211,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
BatchResult* result = &batch_results_[ComputeIndex(input_batch_)];
@@ -220,6 +221,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
@@ -243,6 +245,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
+ mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
@@ -629,6 +632,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ // Used for coordination between the main thread, the runner thread, and
+ // the callback threads.
mutex mu_;
// Used for coordination between the main thread, the runner thread, and
// the callback threads. In particular, the runner thread should only
@@ -636,6 +641,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
+ // Used for serializing external parallelism.
+ mutex external_mu_ ACQUIRED_BEFORE(mu_);
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 8f66f0a7b9..f2724735bf 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -254,6 +254,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
opts.runner = ctx->runner();
opts.stats_collector = ctx->stats_collector();
opts.step_container = ctx->step_container();
+ opts.collective_executor = ctx->collective_executor();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 2b010f816d..23fdfe944a 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -117,10 +117,6 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
}
}
- auto suppress_func = [iou_threshold](const float x) {
- return x <= iou_threshold ? 1 : 0;
- };
-
std::vector<int> selected;
std::vector<float> selected_scores;
Candidate next_candidate;
@@ -134,14 +130,14 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
// Overlapping boxes are likely to have similar scores,
// therefore we iterate through the previously selected boxes backwards
// in order to see if `next_candidate` should be suppressed.
+ bool should_select = true;
for (int j = selected.size() - 1; j >= 0; --j) {
iou = IOU(boxes_data, next_candidate.box_index, selected[j]);
if (iou == 0.0) continue;
- next_candidate.score *= suppress_func(iou);
- if (next_candidate.score <= score_threshold) break;
+ if (iou > iou_threshold) should_select = false;
}
- if (original_score == next_candidate.score) {
+ if (should_select) {
selected.push_back(next_candidate.box_index);
selected_scores.push_back(next_candidate.score);
}
@@ -178,7 +174,7 @@ class NonMaxSuppressionOp : public OpKernel {
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
- const float score_threshold_val = 0.0;
+ const float score_threshold_val = std::numeric_limits<float>::lowest();
DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
iou_threshold_, score_threshold_val);
}
@@ -211,7 +207,7 @@ class NonMaxSuppressionV2Op : public OpKernel {
iou_threshold.shape().DebugString()));
const float iou_threshold_val = iou_threshold.scalar<float>()();
- const float score_threshold_val = 0.0;
+ const float score_threshold_val = std::numeric_limits<float>::lowest();
DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
iou_threshold_val, score_threshold_val);
}
diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc
index c71aa23e01..ed7db313bd 100644
--- a/tensorflow/core/kernels/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc
@@ -86,6 +86,23 @@ TEST_F(NonMaxSuppressionOpTest, TestSelectAtMostTwoBoxesFromThreeClusters) {
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}
+TEST_F(NonMaxSuppressionOpTest, TestSelectWithNegativeScores) {
+ MakeOp(.5);
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(
+ TensorShape({6}), {.9f - 10.0f, .75f - 10.0f, .6f - 10.0f, .95f - 10.0f,
+ .5f - 10.0f, .3f - 10.0f});
+ AddInputFromArray<int>(TensorShape({}), {6});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
TEST_F(NonMaxSuppressionOpTest, TestSelectAtMostThirtyBoxesFromThreeClusters) {
MakeOp(.5);
AddInputFromArray<float>(
@@ -394,6 +411,27 @@ TEST_F(NonMaxSuppressionV3OpTest,
}
TEST_F(NonMaxSuppressionV3OpTest,
+ TestSelectFromThreeClustersWithScoreThresholdZeroScores) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.1, 0, 0, .3, .2, -5.0});
+ // If we ask for more boxes than we actually expect to get back;
+ // should still only get 2 boxes back.
+ AddInputFromArray<int>(TensorShape({}), {6});
+ AddInputFromArray<float>(TensorShape({}), {0.5f});
+ AddInputFromArray<float>(TensorShape({}), {-3.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({2}));
+ test::FillValues<int>(&expected, {3, 0});
+
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionV3OpTest,
TestSelectFromThreeClustersFlippedCoordinates) {
MakeOp();
AddInputFromArray<float>(TensorShape({6, 4}),
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 03cc414905..af921e4815 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -51,6 +51,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif
+#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -72,40 +73,33 @@ namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
-class ReadVariableOp : public OpKernel {
- public:
- explicit ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
- OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
- }
-
- void Compute(OpKernelContext* ctx) override {
- Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, 0);
- const auto status = LookupResource(ctx, handle, &variable);
- OP_REQUIRES(ctx, status.ok(),
- errors::FailedPrecondition(
- "Error while reading resource variable ", handle.name(),
- " from Container: ", handle.container(),
- ". This could mean that the variable was uninitialized. ",
- status.ToString()));
-
- core::ScopedUnref s(variable);
- // We're acquiring a reference to the underlying buffer while
- // holding a shared lock to guarantee ordering of reads and
- // writes.
- tf_shared_lock ml(*variable->mu());
- const Tensor& t = *variable->tensor();
- OP_REQUIRES(
- ctx, dtype_ == t.dtype(),
- errors::InvalidArgument(
- "Trying to read variable with wrong dtype. Expected ",
- DataTypeString(dtype_), " got ", DataTypeString(t.dtype())));
- ctx->set_output(0, t);
- }
-
- private:
- DataType dtype_;
-};
+ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+}
+
+void ReadVariableOp::Compute(OpKernelContext* ctx) {
+ Var* variable = nullptr;
+ ResourceHandle handle = HandleFromInput(ctx, 0);
+ const auto status = LookupResource(ctx, handle, &variable);
+ OP_REQUIRES(ctx, status.ok(),
+ errors::FailedPrecondition(
+ "Error while reading resource variable ", handle.name(),
+ " from Container: ", handle.container(),
+ ". This could mean that the variable was uninitialized. ",
+ status.ToString()));
+
+ core::ScopedUnref s(variable);
+ // We're acquiring a reference to the underlying buffer while
+ // holding a shared lock to guarantee ordering of reads and
+ // writes.
+ tf_shared_lock ml(*variable->mu());
+ const Tensor& t = *variable->tensor();
+ OP_REQUIRES(ctx, dtype_ == t.dtype(),
+ errors::InvalidArgument(
+ "Trying to read variable with wrong dtype. Expected ",
+ DataTypeString(dtype_), " got ", DataTypeString(t.dtype())));
+ ctx->set_output(0, t);
+}
REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
ReadVariableOp);
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
new file mode 100644
index 0000000000..8cae5d21f0
--- /dev/null
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class ReadVariableOp : public OpKernel {
+ public:
+ explicit ReadVariableOp(OpKernelConstruction* c);
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ DataType dtype_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_
diff --git a/tensorflow/core/kernels/scoped_allocator_ops_test.cc b/tensorflow/core/kernels/scoped_allocator_ops_test.cc
index d2918d2042..634f9ba887 100644
--- a/tensorflow/core/kernels/scoped_allocator_ops_test.cc
+++ b/tensorflow/core/kernels/scoped_allocator_ops_test.cc
@@ -37,10 +37,12 @@ namespace tensorflow {
class ScopedAllocatorOpTest : public OpsTestBase {
protected:
- void MakeOp(const gtl::ArraySlice<TensorShape>& shapes, DataType dtype,
+ void MakeOp(const TensorShape& shape,
+ const gtl::ArraySlice<TensorShape>& shapes, DataType dtype,
const string& name, int32 id, int32 expected_call_count) {
TF_EXPECT_OK(NodeDefBuilder("scoped_allocator_op", "_ScopedAllocator")
.Attr("T", dtype)
+ .Attr("shape", shape)
.Attr("shapes", shapes)
.Attr("sa_name", name)
.Attr("id", id)
@@ -61,12 +63,14 @@ class ScopedAllocatorOpTest : public OpsTestBase {
};
TEST_F(ScopedAllocatorOpTest, Simple) {
- MakeOp({TensorShape({8})}, DT_FLOAT, "test", 120, 1);
- MakeOp({TensorShape({32, 32})}, DT_DOUBLE, "test1", 130, 1);
- MakeOp({TensorShape({64}), TensorShape({3, 3}), TensorShape({5, 5, 5})},
+ MakeOp(TensorShape({8}), {TensorShape({8})}, DT_FLOAT, "test", 120, 1);
+ MakeOp(TensorShape({1024}), {TensorShape({32, 32})}, DT_DOUBLE, "test1", 130,
+ 1);
+ MakeOp(TensorShape({204}),
+ {TensorShape({64}), TensorShape({3, 3}), TensorShape({5, 5, 5})},
DT_HALF, "test2", 140, 3);
- MakeOp({TensorShape({512}), TensorShape({64, 8})}, DT_UINT32, "test3", 150,
- 2);
+ MakeOp(TensorShape({1024}), {TensorShape({512}), TensorShape({64, 8})},
+ DT_UINT32, "test3", 150, 2);
}
// PrepOp is common to ConcatOp tests and SplitOpTests.
@@ -254,23 +258,26 @@ TEST_F(ScopedAllocatorConcatOpTest, FailBounds) {
class ScopedAllocatorSplitOpTest : public OpsTestBase {
protected:
- void BuildNodeDef(const TensorShape& shape, DataType dtype,
- const string& name, int32 id, int32 num_tensors) {
+ void BuildNodeDef(const TensorShape& in_shape, DataType dtype,
+ const string& name, int32 id, int32 num_tensors,
+ const std::vector<TensorShape>& out_shapes) {
TF_EXPECT_OK(
NodeDefBuilder("scoped_allocator_split_op", "_ScopedAllocatorSplit")
.Attr("T", dtype)
.Attr("N", num_tensors)
.Attr("sa_name", name)
.Attr("id", id)
+ .Attr("shapes", out_shapes)
.Input(FakeInput(dtype)) // backing tensor and input
.Input(
FakeInput(num_tensors, dtype)) // list of subtensors to forward
.Finalize(node_def()));
}
- void MakeOp(const TensorShape& shape, DataType dtype, const string& name,
- int32 id, int32 num_tensors) {
- BuildNodeDef(shape, dtype, name, id, num_tensors);
+ void MakeOp(const TensorShape& in_shape, DataType dtype, const string& name,
+ int32 id, int32 num_tensors,
+ const std::vector<TensorShape>& out_shapes) {
+ BuildNodeDef(in_shape, dtype, name, id, num_tensors, out_shapes);
TF_EXPECT_OK(InitOp());
}
@@ -310,33 +317,33 @@ class ScopedAllocatorSplitOpTest : public OpsTestBase {
};
TEST_F(ScopedAllocatorSplitOpTest, Success1) {
- MakeOp({32}, DT_FLOAT, "test", 120, 2);
+ MakeOp({32}, DT_FLOAT, "test", 120, 2, {{16}, {16}});
ExecOp(DT_FLOAT, 120, {{16}, {16}});
}
TEST_F(ScopedAllocatorSplitOpTest, Success2) {
- MakeOp({2, 2, 2}, DT_DOUBLE, "test", 120, 2);
+ MakeOp({2, 2, 2}, DT_DOUBLE, "test", 120, 2, {{2, 2}, {2, 2}});
ExecOp(DT_DOUBLE, 120, {{2, 2}, {2, 2}});
}
TEST_F(ScopedAllocatorSplitOpTest, Success3) {
- MakeOp({3, 3, 3}, DT_HALF, "test", 120, 3);
+ MakeOp({3, 3, 3}, DT_HALF, "test", 120, 3, {{3, 3}, {3, 3}, {3, 3}});
ExecOp(DT_HALF, 120, {{3, 3}, {3, 3}, {3, 3}});
}
TEST_F(ScopedAllocatorSplitOpTest, FailNLessThan2) {
- BuildNodeDef({4, 4}, DT_FLOAT, "test", 120, 1);
+ BuildNodeDef({4, 4}, DT_FLOAT, "test", 120, 1, {{4, 4}});
Status s = InitOp();
EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
}
TEST_F(ScopedAllocatorSplitOpTest, FailDtypeCheck) {
- MakeOp({8}, DT_FLOAT, "test", 120, 2);
+ MakeOp({8}, DT_FLOAT, "test", 120, 2, {{4}, {4}});
EXPECT_DEATH(ExecOp(DT_HALF, 120, {{4}, {4}}), "");
}
TEST_F(ScopedAllocatorSplitOpTest, FailBounds) {
- MakeOp({8}, DT_DOUBLE, "test", 120, 2);
+ MakeOp({8}, DT_DOUBLE, "test", 120, 2, {{4}, {4}});
AddInputFromArray<double>({8}, {0, 1, 2, 3, 4, 5, 6, 7});
AddInputFromArray<double>({4}, {0, 1, 2, 3});
AddInputFromArray<double>({4}, {4, 5, 6, 7});
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index c867674489..1920d0a592 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -1500,6 +1500,26 @@ op {
}
}
op {
+ name: "AnonymousIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "Any"
input_arg {
name: "input"
@@ -14195,6 +14215,92 @@ op {
}
}
op {
+ name: "Conv3DBackpropInputV2"
+ input_arg {
+ name: "input_sizes"
+ type_attr: "Tshape"
+ }
+ input_arg {
+ name: "filter"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "out_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NDHWC"
+ }
+ allowed_values {
+ list {
+ s: "NDHWC"
+ s: "NCDHW"
+ }
+ }
+ }
+ attr {
+ name: "dilations"
+ type: "list(int)"
+ default_value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ name: "Tshape"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "Copy"
input_arg {
name: "input"
@@ -42362,6 +42468,21 @@ op {
allows_uninitialized_input: true
}
op {
+ name: "RegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "pattern"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+}
+op {
name: "RegexReplace"
input_arg {
name: "input"
@@ -55296,6 +55417,56 @@ op {
}
}
op {
+ name: "SegmentMean"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ type_attr: "Tindices"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "SegmentMin"
input_arg {
name: "data"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 29d9cfbde9..046049b678 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -584,6 +584,12 @@ REGISTER_OP("Iterator")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("AnonymousIterator")
+ .Output("handle: resource")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("MakeIterator")
.Input("dataset: variant")
.Input("iterator: resource")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index e45125a1e8..d929a5fc87 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -650,6 +650,26 @@ op {
}
}
op {
+ name: "AnonymousIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "Any"
input_arg {
name: "input"
@@ -5927,7 +5947,7 @@ op {
name: "Conv3DBackpropInputV2"
input_arg {
name: "input_sizes"
- type: DT_INT32
+ type_attr: "Tshape"
}
input_arg {
name: "filter"
@@ -5995,6 +6015,19 @@ op {
}
}
}
+ attr {
+ name: "Tshape"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
}
op {
name: "Copy"
@@ -21495,6 +21528,21 @@ op {
allows_uninitialized_input: true
}
op {
+ name: "RegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "pattern"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+}
+op {
name: "RegexReplace"
input_arg {
name: "input"
@@ -26047,9 +26095,14 @@ op {
type: DT_UINT8
type: DT_INT16
type: DT_INT8
+ type: DT_COMPLEX64
type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
type: DT_BFLOAT16
type: DT_UINT16
+ type: DT_COMPLEX128
type: DT_HALF
type: DT_UINT32
type: DT_UINT64
diff --git a/tensorflow/core/ops/scoped_allocator_ops.cc b/tensorflow/core/ops/scoped_allocator_ops.cc
index 1e0dcdac96..359b4d8756 100644
--- a/tensorflow/core/ops/scoped_allocator_ops.cc
+++ b/tensorflow/core/ops/scoped_allocator_ops.cc
@@ -21,6 +21,7 @@ namespace tensorflow {
REGISTER_OP("_ScopedAllocator")
.Output("output: T")
.Attr("shapes: list(shape)")
+ .Attr("shape: shape")
.Attr("T: type")
.Attr("sa_name: string")
.Attr("id: int")
@@ -35,6 +36,16 @@ Returns a reference to this value.
This is an experimental op for internal use only. It is possible to use this
op in unsafe ways.
+
+'shapes' is a list of the shapes of the tensors that are to be allocated
+by this ScopedAllocator.
+'shape' is the shape of the output of this Op, i.e. the 1D backing tensor
+from which the individual allocated tensors are aliased.
+'sa_name' is the name assigned to the Node, for connectivity specification
+and debugging.
+'id' is a non-negative integer 'scope_id' handled by the ScopedAllocatorMgr.
+'expected_call_count' is the number of individual tensors expected to
+be allocated from the backing tensor.
)doc");
REGISTER_OP("_ScopedAllocatorConcat")
@@ -57,6 +68,18 @@ reference to that ScopedAllocator's backing tensor.
This is an experimental op for internal use only. It is possible to use this
op in unsafe ways.
+
+'backing' is the backing tensor, i.e. the output of an upstream ScopedAllocator.
+'inputs' is a list of nominal input tensors, all of which must be aliases
+to regions of the backing tensor. These will be outputs of upstream nodes
+that allocate their outputs from the same ScopedAllocator.
+'shape' is the shape of the output, which will usually be the same shape as
+the input backing tensor.
+'reshape' is true iff the output shape is to be different from that of
+the input backing tensor.
+'sa_name' is the Node name of the upstream ScopedAllocator.
+'id' is the scope_id identifying the upstream ScopedAllocator.
+'N' is the number of nominal inputs to be concatenated.
)doc");
REGISTER_OP("_ScopedAllocatorSplit")
@@ -67,8 +90,9 @@ REGISTER_OP("_ScopedAllocatorSplit")
.Attr("sa_name: string")
.Attr("id: int")
.Attr("N: int >= 2")
+ .Attr("shapes: list(shape)")
.SetIsStateful()
- .SetShapeFn(shape_inference::ExplicitShape)
+ .SetShapeFn(shape_inference::ExplicitShapes)
.Doc(R"doc(
Acts roughly like a SplitV Op that splits one tensor into multiple tensors
but must only be used in conjunction with corresponding ScopedAllocator
@@ -79,6 +103,17 @@ second list.
This is an experimental op for internal use only. It is possible to use this
op in unsafe ways.
+
+'concat' is the single output produced by an upstream ScopedAllocatorConcat
+node. This is actually the backing tensor from a ScopedAllocator node
+upstream of the ScopedAllocatorConcat.
+'split' is a list of tensors aliased from the backing tensor. It will
+become the output of this ScopedAllocatorSplit node.
+'type' is the common DataType of all of the input and output tensors.
+'sa_name' is the Node name of the upstream ScopedAllocator.
+'id' is the scope_id identifying the upstream ScopedAllocator.
+'N' is the number of split tensors.
+'shapes' is a list of the split tensor shapes.
)doc");
} // end namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index 081d4cf043..a1be4aacce 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -112,10 +112,6 @@ class LibCurlProxy : public LibCurl {
}
void curl_free(void* p) override { ::curl_free(p); }
-
- const char* curl_easy_strerror(CURLcode errornum) override {
- return ::curl_easy_strerror(errornum);
- }
};
} // namespace
@@ -313,7 +309,7 @@ void CurlHttpRequest::SetResultBufferDirect(char* buffer, size_t size) {
CHECK(buffer != nullptr);
CheckNotSent();
- direct_response_ = DirectResponseState{buffer, size, 0};
+ direct_response_ = DirectResponseState{buffer, size, 0, 0};
CHECK_CURL_OK(libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA,
reinterpret_cast<void*>(this)));
CHECK_CURL_OK(libcurl_->curl_easy_setopt(
@@ -335,24 +331,15 @@ size_t CurlHttpRequest::WriteCallbackDirect(const void* ptr, size_t size,
size_t curl_bytes_received = size * nmemb;
size_t user_buffer_bytes_available =
state->buffer_size_ - state->bytes_transferred_;
-
- // The HTTP server may send a response body that is longer than what we
- // expected. We must not use CHECK() for this situation, because that would
- // imply a code bug (in this client code) where none exists; the violation of
- // expectations would have been caused by the server, not the client. So we
- // report a log warning, if an HTTP server is misbehaving.
- if (curl_bytes_received > user_buffer_bytes_available) {
- LOG(WARNING) << "The HTTP response body that we received is longer than we "
- "requested or expected. "
- << "Total bytes requested: " << state->buffer_size_
- << " Bytes received (so far) in HTTP response body: "
- << (state->bytes_transferred_ + curl_bytes_received);
- }
-
size_t bytes_to_copy =
std::min<size_t>(curl_bytes_received, user_buffer_bytes_available);
memcpy(&state->buffer_[state->bytes_transferred_], ptr, bytes_to_copy);
state->bytes_transferred_ += bytes_to_copy;
+ state->bytes_received_ += curl_bytes_received;
+ // If we didn't have room to store the full response, returning less than
+ // curl_bytes_received here will abort the transfer and curl_easy_perform()
+ // will return CURLE_WRITE_ERROR. We will detect and handle this error there,
+ // and can use state->bytes_received_ as stored above for logging purposes.
return bytes_to_copy;
}
@@ -447,23 +434,7 @@ Status CurlHttpRequest::Send() {
}
const CURLcode curl_result = libcurl_->curl_easy_perform(curl_);
- TF_CURL_RETURN_WITH_CONTEXT_IF_ERROR(
- curl_result, "Performing request. Detailed error: ", error_buffer);
-
- auto get_error_message = [this, curl_result, &error_buffer]() -> string {
- StringPiece response = GetResponse();
- string error_message = strings::StrCat(
- "Error executing an HTTP request (HTTP response code ", response_code_,
- ", error code ", curl_result, ", error message '", error_buffer, "')");
- if (!response.empty()) {
- return strings::StrCat(
- error_message, ", response '",
- response.substr(0,
- std::min(response.size(), response_to_error_limit_)),
- "'");
- }
- return error_message;
- };
+ TF_RETURN_IF_ERROR(CURLcodeToStatus(curl_result, error_buffer));
double written_size = 0;
CHECK_CURL_OK(libcurl_->curl_easy_getinfo(curl_, CURLINFO_SIZE_DOWNLOAD,
@@ -472,6 +443,18 @@ Status CurlHttpRequest::Send() {
CHECK_CURL_OK(libcurl_->curl_easy_getinfo(curl_, CURLINFO_RESPONSE_CODE,
&response_code_));
+ auto get_error_message = [this]() -> string {
+ string error_message = strings::StrCat(
+ "Error executing an HTTP request: HTTP response code ", response_code_);
+ StringPiece body = GetResponse();
+ if (!body.empty()) {
+ return strings::StrCat(
+ error_message, " with body '",
+ body.substr(0, std::min(body.size(), response_to_error_limit_)), "'");
+ }
+ return error_message;
+ };
+
Status result;
switch (response_code_) {
// The group of response codes indicating that the request achieved
@@ -485,9 +468,12 @@ Status CurlHttpRequest::Send() {
case 416: // Requested Range Not Satisfiable
// The requested range had no overlap with the available range.
- // This doesn't indicate an error, but this does mean an empty response
- // body.
+ // This doesn't indicate an error, but we should produce an empty response
+ // body. (Not all servers do; GCS returns a short error message body.)
response_buffer_->clear();
+ if (IsDirectResponse()) {
+ direct_response_.bytes_transferred_ = 0;
+ }
result = Status::OK();
break;
@@ -613,14 +599,13 @@ int CurlHttpRequest::ProgressCallback(void* this_object, curl_off_t dltotal,
<< " bytes for " << now - that->last_progress_timestamp_
<< " seconds and will be aborted. CURL timing information: "
<< "lookup time: " << lookup_time << " ("
- << that->libcurl_->curl_easy_strerror(lookup_time_status)
+ << curl_easy_strerror(lookup_time_status)
<< "), connect time: " << connect_time << " ("
- << that->libcurl_->curl_easy_strerror(connect_time_status)
+ << curl_easy_strerror(connect_time_status)
<< "), pre-transfer time: " << pretransfer_time << " ("
- << that->libcurl_->curl_easy_strerror(pretransfer_time_status)
+ << curl_easy_strerror(pretransfer_time_status)
<< "), start-transfer time: " << starttransfer_time << " ("
- << that->libcurl_->curl_easy_strerror(starttransfer_time_status)
- << ")";
+ << curl_easy_strerror(starttransfer_time_status) << ")";
return 1; // Will abort the request.
}
@@ -628,12 +613,36 @@ int CurlHttpRequest::ProgressCallback(void* this_object, curl_off_t dltotal,
return 0;
}
-Status CURLcodeToStatus(CURLcode code) {
- // Return Unavailable to retry by default. We probably should distinguish
- // between permanent or temporary failures.
- return errors::Unavailable("Error executing an HTTP request (error code ",
- code, ", error message '",
- curl_easy_strerror(code), "')");
+Status CurlHttpRequest::CURLcodeToStatus(CURLcode code,
+ const char* error_buffer) {
+ if (code == CURLE_OK) {
+ return Status::OK();
+ }
+ string error_message = strings::StrCat(
+ "Error executing an HTTP request: libcurl code ", code, " meaning '",
+ curl_easy_strerror(code), "', error details: ");
+ // Special-case response-too-large errors as FAILED_PRECONDITION.
+ if (code == CURLE_WRITE_ERROR && IsDirectResponse() &&
+ direct_response_.bytes_received_ > direct_response_.buffer_size_) {
+ string overflow_message = strings::StrCat(
+ "Received ", direct_response_.bytes_received_, " response bytes ",
+ "for a ", direct_response_.buffer_size_, "-byte buffer");
+ uint64 response_code = 0;
+ const CURLcode get_response_result = libcurl_->curl_easy_getinfo(
+ curl_, CURLINFO_RESPONSE_CODE, &response_code);
+ // Special-case 416 Range Not Satisfied responses; they sometimes have
+ // a response body (e.g. GCS sends one with an error message) but we
+ // pretend as though they don't, so actually ignore this error.
+ if (get_response_result == CURLE_OK && response_code == 416) {
+ return Status::OK();
+ }
+ return errors::FailedPrecondition(
+ strings::StrCat(error_message, overflow_message));
+ }
+ // Return Unavailable to retry by default. There may be other permanent
+ // failures that should be distinguished.
+ return errors::Unavailable(
+ strings::StrCat(error_message, *error_buffer ? error_buffer : "(none)"));
}
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h
index e658948ab9..1b2029926d 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.h
+++ b/tensorflow/core/platform/cloud/curl_http_request.h
@@ -167,6 +167,10 @@ class CurlHttpRequest : public HttpRequest {
void CheckNotSent() const;
StringPiece GetResponse() const;
+ /// Helper to convert the given CURLcode and error buffer, representing the
+ /// result of performing a transfer, into a Status with an error message.
+ Status CURLcodeToStatus(CURLcode code, const char* error_buffer);
+
LibCurl* libcurl_;
Env* env_;
@@ -181,6 +185,7 @@ class CurlHttpRequest : public HttpRequest {
char* buffer_;
size_t buffer_size_;
size_t bytes_transferred_;
+ size_t bytes_received_;
};
DirectResponseState direct_response_ = {};
@@ -261,21 +266,8 @@ class LibCurl {
virtual void curl_slist_free_all(curl_slist* list) = 0;
virtual char* curl_easy_escape(CURL* curl, const char* str, int length) = 0;
virtual void curl_free(void* p) = 0;
-
- virtual const char* curl_easy_strerror(CURLcode errornum) = 0;
};
-Status CURLcodeToStatus(CURLcode code);
-
-#define TF_CURL_RETURN_WITH_CONTEXT_IF_ERROR(_code, ...) \
- do { \
- if (_code != CURLE_OK) { \
- ::tensorflow::Status _status = ::tensorflow::CURLcodeToStatus(_code); \
- ::tensorflow::errors::AppendToMessage(&_status, __VA_ARGS__); \
- return _status; \
- } \
- } while (0)
-
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_CURL_HTTP_REQUEST_H_
diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc
index 522b717568..eb9023d708 100644
--- a/tensorflow/core/platform/cloud/curl_http_request_test.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc
@@ -149,8 +149,12 @@ class FakeLibCurl : public LibCurl {
} while (bytes_read > 0);
}
if (write_data_ || write_callback_) {
- write_callback_(response_content_.c_str(), 1, response_content_.size(),
- write_data_);
+ size_t bytes_handled = write_callback_(
+ response_content_.c_str(), 1, response_content_.size(), write_data_);
+ // Mimic real libcurl behavior by checking write callback return value.
+ if (bytes_handled != response_content_.size()) {
+ curl_easy_perform_result_ = CURLE_WRITE_ERROR;
+ }
}
for (const auto& header : response_headers_) {
header_callback_(header.c_str(), 1, header.size(), header_data_);
@@ -219,10 +223,6 @@ class FakeLibCurl : public LibCurl {
}
void curl_free(void* p) override { port::Free(p); }
- const char* curl_easy_strerror(CURLcode errornum) override {
- return "<unimplemented>";
- }
-
// Variables defining the behavior of this fake.
string response_content_;
uint64 response_code_;
@@ -302,7 +302,7 @@ TEST(CurlHttpRequestTest, GetRequest_Direct) {
string expected_response = "get response";
size_t response_bytes_transferred =
http_request.GetResultBufferDirectBytesTransferred();
- EXPECT_EQ(response_bytes_transferred, expected_response.size());
+ EXPECT_EQ(expected_response.size(), response_bytes_transferred);
EXPECT_EQ(
"get response",
string(scratch.begin(), scratch.begin() + response_bytes_transferred));
@@ -318,6 +318,48 @@ TEST(CurlHttpRequestTest, GetRequest_Direct) {
EXPECT_EQ(200, http_request.GetResponseCode());
}
+TEST(CurlHttpRequestTest, GetRequest_Direct_ResponseTooLarge) {
+ FakeLibCurl libcurl("get response", 200);
+ CurlHttpRequest http_request(&libcurl);
+
+ std::vector<char> scratch(5, 0);
+
+ http_request.SetUri("http://www.testuri.com");
+ http_request.SetResultBufferDirect(scratch.data(), scratch.size());
+ const Status& status = http_request.Send();
+ EXPECT_EQ(error::FAILED_PRECONDITION, status.code());
+ EXPECT_EQ(
+ "Error executing an HTTP request: libcurl code 23 meaning "
+ "'Failed writing received data to disk/application', error details: "
+ "Received 12 response bytes for a 5-byte buffer",
+ status.error_message());
+
+ // As long as the request clearly fails, ok to leave truncated response here.
+ EXPECT_EQ(5, http_request.GetResultBufferDirectBytesTransferred());
+ EXPECT_EQ("get r", string(scratch.begin(), scratch.begin() + 5));
+}
+
+TEST(CurlHttpRequestTest, GetRequest_Direct_RangeOutOfBound) {
+ FakeLibCurl libcurl("get response", 416);
+ CurlHttpRequest http_request(&libcurl);
+
+ const string initialScratch = "abcde";
+ std::vector<char> scratch;
+ scratch.insert(scratch.end(), initialScratch.begin(), initialScratch.end());
+
+ http_request.SetUri("http://www.testuri.com");
+ http_request.SetRange(0, 4);
+ http_request.SetResultBufferDirect(scratch.data(), scratch.size());
+ TF_EXPECT_OK(http_request.Send());
+ EXPECT_EQ(416, http_request.GetResponseCode());
+
+ // Some servers (in particular, GCS) return an error message payload with a
+ // 416 Range Not Satisfiable response. We should pretend it's not there when
+ // reporting bytes transferred, but it's ok if it writes to scratch.
+ EXPECT_EQ(0, http_request.GetResultBufferDirectBytesTransferred());
+ EXPECT_EQ("get r", string(scratch.begin(), scratch.end()));
+}
+
TEST(CurlHttpRequestTest, GetRequest_Empty) {
FakeLibCurl libcurl("", 200);
CurlHttpRequest http_request(&libcurl);
@@ -357,28 +399,26 @@ TEST(CurlHttpRequestTest, GetRequest_RangeOutOfBound) {
http_request.SetResultBuffer(&scratch);
TF_EXPECT_OK(http_request.Send());
+ // Some servers (in particular, GCS) return an error message payload with a
+ // 416 Range Not Satisfiable response. We should pretend it's not there.
EXPECT_TRUE(scratch.empty());
EXPECT_EQ(416, http_request.GetResponseCode());
}
TEST(CurlHttpRequestTest, GetRequest_503) {
FakeLibCurl libcurl("get response", 503);
- libcurl.curl_easy_perform_result_ = CURLE_WRITE_ERROR;
CurlHttpRequest http_request(&libcurl);
std::vector<char> scratch;
scratch.insert(scratch.end(), kTestContent.begin(), kTestContent.end());
http_request.SetUri("http://www.testuri.com");
- http_request.AddAuthBearerHeader("fake-bearer");
- http_request.SetRange(100, 199);
http_request.SetResultBuffer(&scratch);
const auto& status = http_request.Send();
EXPECT_EQ(error::UNAVAILABLE, status.code());
EXPECT_EQ(
- "Error executing an HTTP request (error code 23, error message 'Failed "
- "writing received data to disk/application')\n\tPerforming request. "
- "Detailed error: ",
+ "Error executing an HTTP request: HTTP response code 503 with body "
+ "'get response'",
status.error_message());
}
@@ -395,9 +435,8 @@ TEST(CurlHttpRequestTest, GetRequest_HttpCode0) {
const auto& status = http_request.Send();
EXPECT_EQ(error::UNAVAILABLE, status.code());
EXPECT_EQ(
- "Error executing an HTTP request (error code 28, error message 'Timeout "
- "was reached')\n\tPerforming request. Detailed error: Operation timed "
- "out",
+ "Error executing an HTTP request: libcurl code 28 meaning "
+ "'Timeout was reached', error details: Operation timed out",
status.error_message());
EXPECT_EQ(0, http_request.GetResponseCode());
}
@@ -630,9 +669,8 @@ TEST(CurlHttpRequestTest, ProgressIsStuck) {
auto status = http_request.Send();
EXPECT_EQ(error::UNAVAILABLE, status.code());
EXPECT_EQ(
- "Error executing an HTTP request (error code 42, error message "
- "'Operation was aborted by an application callback')\n\tPerforming "
- "request. Detailed error: ",
+ "Error executing an HTTP request: libcurl code 42 meaning 'Operation "
+ "was aborted by an application callback', error details: (none)",
status.error_message());
}
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index e91c3b1b0c..dc12c78a4b 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -1103,7 +1103,8 @@ Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
}
};
GcsFileStat stat;
- Status s = stat_cache_->LookupOrCompute(dirname, &stat, compute_func);
+ Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname), &stat,
+ compute_func);
if (s.ok()) {
*result = stat.base.is_directory;
return Status::OK();
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index bb4ace65a9..3f73b238ad 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -1107,7 +1107,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) {
"\"updated\": \"2016-04-29T23:15:24.896Z\"}")),
new FakeHttpRequest(
"Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
- "path%2Fsubfolder?fields=size%2Cgeneration%2Cupdated\n"
+ "path%2Fsubfolder%2F?fields=size%2Cgeneration%2Cupdated\n"
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404),
@@ -1133,7 +1133,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) {
// HTTP requests.
for (int i = 0; i < 10; i++) {
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/file1.txt"));
- TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder"));
+ TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder/"));
}
}
@@ -1932,6 +1932,14 @@ TEST(GcsFileSystemTest, RenameFile_Object) {
"Range: 0-15\n"
"Timeouts: 5 1 20\n",
"76543210"),
+ // IsDirectory is checking whether there are children objects.
+ new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+ "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F"
+ "&maxResults=1\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ "{}"),
// Copying to the new location.
new FakeHttpRequest(
"Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
@@ -2318,7 +2326,7 @@ TEST(GcsFileSystemTest, Stat_Cache) {
"\"updated\": \"2016-04-29T23:15:24.896Z\"}")),
new FakeHttpRequest(
"Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
- "subfolder?fields=size%2Cgeneration%2Cupdated\n"
+ "subfolder%2F?fields=size%2Cgeneration%2Cupdated\n"
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404),
@@ -2348,7 +2356,7 @@ TEST(GcsFileSystemTest, Stat_Cache) {
EXPECT_EQ(1010, stat.length);
EXPECT_NEAR(1461971724896, stat.mtime_nsec / 1000 / 1000, 1);
EXPECT_FALSE(stat.is_directory);
- TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder", &stat));
+ TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder/", &stat));
EXPECT_EQ(0, stat.length);
EXPECT_EQ(0, stat.mtime_nsec);
EXPECT_TRUE(stat.is_directory);
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index b4b756b866..23c594d90d 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -82,7 +82,7 @@ def pyx_library(
native.cc_binary(
name=shared_object_name,
srcs=[stem + ".cpp"],
- deps=deps + ["//util/python:python_headers"],
+ deps=deps + ["//third_party/python_runtime:headers"],
linkshared = 1,
)
shared_objects.append(shared_object_name)
@@ -495,14 +495,6 @@ def tf_additional_lib_srcs(exclude = []):
], exclude = exclude),
})
-# pylint: disable=unused-argument
-def tf_additional_framework_hdrs(exclude = []):
- return []
-
-def tf_additional_framework_srcs(exclude = []):
- return []
-# pylint: enable=unused-argument
-
def tf_additional_minimal_lib_srcs():
return [
"platform/default/integral_types.h",
diff --git a/tensorflow/core/platform/default/string_coding.cc b/tensorflow/core/platform/default/string_coding.cc
new file mode 100644
index 0000000000..7410ee6782
--- /dev/null
+++ b/tensorflow/core/platform/default/string_coding.cc
@@ -0,0 +1,30 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/default/string_coding.h"
+
+namespace tensorflow {
+namespace port {
+
+std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out) {
+ return std::unique_ptr<StringListEncoder>(new StringListEncoder(out));
+}
+
+std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in) {
+ return std::unique_ptr<StringListDecoder>(new StringListDecoder(in));
+}
+
+} // namespace port
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/string_coding.h b/tensorflow/core/platform/default/string_coding.h
new file mode 100644
index 0000000000..70b8ab0144
--- /dev/null
+++ b/tensorflow/core/platform/default/string_coding.h
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STRING_CODING_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_STRING_CODING_H_
+
+// IWYU pragma: private, include "third_party/tensorflow/core/platform/tensor_coding.h"
+// IWYU pragma: friend third_party/tensorflow/core/platform/tensor_coding.h
+
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace port {
+
+// Encodes sequences of strings and serialized protocol buffers into a string.
+// Normal usage consists of zero or more calls to Append() and a single call to
+// Finalize().
+class StringListEncoder {
+ public:
+ explicit StringListEncoder(string* out) : out_(out) {}
+
+ // Encodes the given protocol buffer. This may not be called after Finalize().
+ void Append(const protobuf::MessageLite& m) {
+ core::PutVarint32(out_, m.ByteSize());
+ m.AppendToString(&rest_);
+ }
+
+ // Encodes the given string. This may not be called after Finalize().
+ void Append(const string& s) {
+ core::PutVarint32(out_, s.length());
+ strings::StrAppend(&rest_, s);
+ }
+
+ // Signals end of the encoding process. No other calls are allowed after this.
+ void Finalize() { strings::StrAppend(out_, rest_); }
+
+ private:
+ string* out_;
+ string rest_;
+};
+
+// Decodes a string into sequences of strings (which may represent serialized
+// protocol buffers). Normal usage involves a single call to ReadSizes() in
+// order to retrieve the length of all the strings in the sequence. For each
+// size returned a call to Data() is expected and will return the actual
+// string.
+class StringListDecoder {
+ public:
+ explicit StringListDecoder(const string& in) : reader_(in) {}
+
+ // Populates the given vector with the lengths of each string in the sequence
+ // being decoded. Upon returning the vector is guaranteed to contain as many
+ // elements as there are strings in the sequence.
+ bool ReadSizes(std::vector<uint32>* sizes) {
+ int64 total = 0;
+ for (auto& size : *sizes) {
+ if (!core::GetVarint32(&reader_, &size)) return false;
+ total += size;
+ }
+ if (total != static_cast<int64>(reader_.size())) {
+ return false;
+ }
+ return true;
+ }
+
+ // Returns a pointer to the next string in the sequence, then prepares for the
+ // next call by advancing 'size' characters in the sequence.
+ const char* Data(uint32 size) {
+ const char* data = reader_.data();
+ reader_.remove_prefix(size);
+ return data;
+ }
+
+ private:
+ StringPiece reader_;
+};
+
+std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out);
+std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in);
+
+} // namespace port
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_STRING_CODING_H_
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index fe7d0aa7d1..47c59d435b 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -26,7 +26,7 @@ limitations under the License.
#endif
#if defined(PLATFORM_WINDOWS)
#include <windows.h>
-#include "tensorflow/core/platform/windows/windows_file_system.h"
+#include "tensorflow/core/platform/windows/wide_char.h"
#define PATH_MAX MAX_PATH
#else
#include <unistd.h>
@@ -311,7 +311,7 @@ string Env::GetExecutablePath() {
HMODULE hModule = GetModuleHandleW(NULL);
WCHAR wc_file_path[MAX_PATH] = {0};
GetModuleFileNameW(hModule, wc_file_path, MAX_PATH);
- string file_path = WindowsFileSystem::WideCharToUtf8(wc_file_path);
+ string file_path = WideCharToUtf8(wc_file_path);
std::copy(file_path.begin(), file_path.end(), exe_path);
#else
CHECK_NE(-1, readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1));
diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc
index 17dc81f7e0..84601de39a 100644
--- a/tensorflow/core/platform/tensor_coding.cc
+++ b/tensorflow/core/platform/tensor_coding.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/platform/tensor_coding.h"
#include <vector>
-#include "tensorflow/core/framework/resource_handle.pb.h"
+
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -66,39 +66,5 @@ void CopyFromArray(string* s, const char* base, size_t bytes) {
s->assign(base, bytes);
}
-void EncodeResourceHandleList(const ResourceHandle* p, int64 n, string* out) {
- out->clear();
- string rest;
- ResourceHandleProto proto;
- for (int i = 0; i < n; ++i) {
- p[i].AsProto(&proto);
- core::PutVarint32(out, proto.ByteSize());
- proto.AppendToString(&rest);
- }
- *out += rest;
-}
-
-bool DecodeResourceHandleList(const string& in, ResourceHandle* ps, int64 n) {
- std::vector<uint32> sizes(n);
- StringPiece reader(in);
- int64 total = 0;
- for (auto& size : sizes) {
- if (!core::GetVarint32(&reader, &size)) return false;
- total += size;
- }
- if (total != static_cast<int64>(reader.size())) {
- return false;
- }
- ResourceHandleProto proto;
- for (int i = 0; i < n; ++i) {
- if (!proto.ParseFromArray(reader.data(), sizes[i])) {
- return false;
- }
- ps[i].FromProto(proto);
- reader.remove_prefix(sizes[i]);
- }
- return true;
-}
-
} // namespace port
} // namespace tensorflow
diff --git a/tensorflow/core/platform/tensor_coding.h b/tensorflow/core/platform/tensor_coding.h
index 19f53e6374..6c6d75830d 100644
--- a/tensorflow/core/platform/tensor_coding.h
+++ b/tensorflow/core/platform/tensor_coding.h
@@ -18,7 +18,6 @@ limitations under the License.
#define TENSORFLOW_PLATFORM_TENSOR_CODING_H_
#include <string>
-#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/platform.h"
@@ -26,6 +25,8 @@ limitations under the License.
#ifdef PLATFORM_GOOGLE
#include "tensorflow/core/platform/google/cord_coding.h"
+#else
+#include "tensorflow/core/platform/default/string_coding.h"
#endif
namespace tensorflow {
@@ -51,13 +52,6 @@ bool DecodeStringList(const string& src, string* strings, int64 n);
// Assigns base[0..bytes-1] to *s
void CopyFromArray(string* s, const char* base, size_t bytes);
-// Encodes a list of ResourceHandle protos in the given string.
-void EncodeResourceHandleList(const ResourceHandle* handles, int64 n,
- string* out);
-
-// Decodes a list of ResourceHandle protos from the given string.
-bool DecodeResourceHandleList(const string& in, ResourceHandle* ps, int64 n);
-
} // namespace port
} // namespace tensorflow
diff --git a/tensorflow/core/platform/variant_coding.cc b/tensorflow/core/platform/variant_coding.cc
deleted file mode 100644
index 48c5389d29..0000000000
--- a/tensorflow/core/platform/variant_coding.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/platform/variant_coding.h"
-
-#include <vector>
-#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/framework/variant_op_registry.h"
-#include "tensorflow/core/lib/core/coding.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-namespace tensorflow {
-namespace port {
-
-void EncodeVariantList(const Variant* variant_array, int64 n, string* out) {
- out->clear();
- string rest;
- for (int i = 0; i < n; ++i) {
- string s;
- variant_array[i].Encode(&s);
- core::PutVarint32(out, s.length());
- strings::StrAppend(&rest, s);
- }
- strings::StrAppend(out, rest);
-}
-
-bool DecodeVariantList(const string& in, Variant* variant_array, int64 n) {
- std::vector<uint32> sizes(n);
- StringPiece reader(in);
- int64 total = 0;
- for (auto& size : sizes) {
- if (!core::GetVarint32(&reader, &size)) return false;
- total += size;
- }
- if (total != static_cast<int64>(reader.size())) {
- return false;
- }
-
- for (int i = 0; i < n; ++i) {
- if (variant_array[i].is_empty()) {
- variant_array[i] = VariantTensorDataProto();
- }
- string str(reader.data(), sizes[i]);
- if (!variant_array[i].Decode(str)) return false;
- if (!DecodeUnaryVariant(&variant_array[i])) {
- LOG(ERROR) << "Could not decode variant with type_name: \""
- << variant_array[i].TypeName()
- << "\". Perhaps you forgot to register a "
- "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
- return false;
- }
- reader.remove_prefix(sizes[i]);
- }
- return true;
-}
-
-} // end namespace port
-} // end namespace tensorflow
diff --git a/tensorflow/core/platform/variant_coding.h b/tensorflow/core/platform/variant_coding.h
deleted file mode 100644
index a971857e4a..0000000000
--- a/tensorflow/core/platform/variant_coding.h
+++ /dev/null
@@ -1,40 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_PLATFORM_VARIANT_CODING_H_
-#define TENSORFLOW_PLATFORM_VARIANT_CODING_H_
-
-#include "tensorflow/core/framework/variant.h"
-#include "tensorflow/core/framework/variant_encode_decode.h"
-
-#ifdef PLATFORM_GOOGLE
-#include "tensorflow/core/platform/google/variant_cord_coding.h"
-#endif
-
-namespace tensorflow {
-namespace port {
-
-// Encodes an array of Variant objects in to the given string.
-// `variant_array` is assumed to point to an array of `n` Variant objects.
-void EncodeVariantList(const Variant* variant_array, int64 n, string* out);
-
-// Decodes an array of Variant objects from the given string.
-// `variant_array` is assumed to point to an array of `n` Variant objects.
-bool DecodeVariantList(const string& in, Variant* variant_array, int64 n);
-
-} // end namespace port
-} // end namespace tensorflow
-
-#endif // TENSORFLOW_PLATFORM_VARIANT_CODING_H_
diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc
index 2f54f423b2..68ee3595a2 100644
--- a/tensorflow/core/platform/windows/env.cc
+++ b/tensorflow/core/platform/windows/env.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/platform/load_library.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/windows/wide_char.h"
#include "tensorflow/core/platform/windows/windows_file_system.h"
#pragma comment(lib, "Shlwapi.lib")
@@ -71,8 +72,8 @@ class WindowsEnv : public Env {
}
bool MatchPath(const string& path, const string& pattern) override {
- std::wstring ws_path(WindowsFileSystem::Utf8ToWideChar(path));
- std::wstring ws_pattern(WindowsFileSystem::Utf8ToWideChar(pattern));
+ std::wstring ws_path(Utf8ToWideChar(path));
+ std::wstring ws_pattern(Utf8ToWideChar(pattern));
return PathMatchSpecW(ws_path.c_str(), ws_pattern.c_str()) == TRUE;
}
@@ -125,7 +126,7 @@ class WindowsEnv : public Env {
std::string file_name = library_filename;
std::replace(file_name.begin(), file_name.end(), '/', '\\');
- std::wstring ws_file_name(WindowsFileSystem::Utf8ToWideChar(file_name));
+ std::wstring ws_file_name(Utf8ToWideChar(file_name));
HMODULE hModule = LoadLibraryExW(ws_file_name.c_str(), NULL,
LOAD_WITH_ALTERED_SEARCH_PATH);
diff --git a/tensorflow/core/platform/windows/wide_char.h b/tensorflow/core/platform/windows/wide_char.h
new file mode 100644
index 0000000000..1b86abc3fa
--- /dev/null
+++ b/tensorflow/core/platform/windows/wide_char.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_WIDE_CHAR_H_
+#define TENSORFLOW_CORE_PLATFORM_WINDOWS_WIDE_CHAR_H_
+
+#include <Windows.h>
+#include <string>
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+inline std::wstring Utf8ToWideChar(const string& utf8str) {
+ int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(),
+ (int)utf8str.size(), NULL, 0);
+ std::wstring ws_translated_str(size_required, 0);
+ MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(),
+ &ws_translated_str[0], size_required);
+ return ws_translated_str;
+}
+
+inline string WideCharToUtf8(const std::wstring& wstr) {
+ if (wstr.empty()) return std::string();
+ int size_required = WideCharToMultiByte(
+ CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), NULL, 0, NULL, NULL);
+ string utf8_translated_str(size_required, 0);
+ WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), (int)wstr.size(),
+ &utf8_translated_str[0], size_required, NULL, NULL);
+ return utf8_translated_str;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_WIDE_CHAR_H_
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index dc2efbeaf5..9079a5ccaa 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/posix/error.h"
#include "tensorflow/core/platform/windows/error.h"
+#include "tensorflow/core/platform/windows/wide_char.h"
#include "tensorflow/core/platform/windows/windows_file_system.h"
// TODO(mrry): Prevent this Windows.h #define from leaking out of our headers.
diff --git a/tensorflow/core/platform/windows/windows_file_system.h b/tensorflow/core/platform/windows/windows_file_system.h
index ba0302f0fd..6b04720c68 100644
--- a/tensorflow/core/platform/windows/windows_file_system.h
+++ b/tensorflow/core/platform/windows/windows_file_system.h
@@ -64,25 +64,6 @@ class WindowsFileSystem : public FileSystem {
Status RenameFile(const string& src, const string& target) override;
string TranslateName(const string& name) const override { return name; }
-
- static std::wstring Utf8ToWideChar(const string& utf8str) {
- int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(),
- (int)utf8str.size(), NULL, 0);
- std::wstring ws_translated_str(size_required, 0);
- MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), (int)utf8str.size(),
- &ws_translated_str[0], size_required);
- return ws_translated_str;
- }
-
- static string WideCharToUtf8(const std::wstring& wstr) {
- if (wstr.empty()) return std::string();
- int size_required = WideCharToMultiByte(
- CP_UTF8, 0, wstr.c_str(), (int)wstr.size(), NULL, 0, NULL, NULL);
- string utf8_translated_str(size_required, 0);
- WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), (int)wstr.size(),
- &utf8_translated_str[0], size_required, NULL, NULL);
- return utf8_translated_str;
- }
};
class LocalWinFileSystem : public WindowsFileSystem {
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index a28ed06fb3..9a48f43a63 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -379,7 +379,17 @@ message ConfigProto {
// shared with other sessions.
bool isolate_session_state = 15;
- // Next: 16
+ // Everything inside Experimental is subject to change and is not subject
+ // to API stability guarantees in
+ // https://www.tensorflow.org/programmers_guide/version_compat.
+ message Experimental {
+ // Task name for group resolution.
+ string collective_group_leader = 1;
+ };
+
+ Experimental experimental = 16;
+
+ // Next: 17
};
// Options for a single Run() call.
@@ -414,6 +424,19 @@ message RunOptions {
// Enabling this option can slow down the Run() call.
bool report_tensor_allocations_upon_oom = 7;
+ // Everything inside Experimental is subject to change and is not subject
+ // to API stability guarantees in
+ // https://www.tensorflow.org/programmers_guide/version_compat.
+ message Experimental {
+ // If non-zero, declares that this graph is going to use collective
+ // ops and must synchronize step_ids with any other graph with this
+ // same group_key value (in a distributed computation where tasks
+ // run disjoint graphs).
+ int64 collective_graph_key = 1;
+ };
+
+ Experimental experimental = 8;
+
reserved 4;
}
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 45e57594e4..bbb25d6f3f 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -14,6 +14,11 @@ message AutoParallelOptions {
int32 num_replicas = 2;
}
+message ScopedAllocatorOptions {
+ // If present, only perform optimization for these ops.
+ repeated string enable_op = 1;
+}
+
message RewriterConfig {
// Graph rewriting is experimental and subject to change, not covered by any
// API stability guarantees.
@@ -67,6 +72,9 @@ message RewriterConfig {
Toggle debug_stripper = 11;
// If true, don't remove unnecessary ops from the graph
bool disable_model_pruning = 2;
+ // Try to allocate some independent Op outputs contiguously in order to
+ // merge or eliminate downstream Ops (off by default).
+ Toggle scoped_allocator_optimization = 15;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
@@ -115,6 +123,8 @@ message RewriterConfig {
// meta-optimizer or when manually specified through the optimizers field.
AutoParallelOptions auto_parallel = 5;
+ ScopedAllocatorOptions scoped_allocator_opts = 16;
+
// If non-empty, will use this as an alternative way to specify a list of
// optimizations to turn on and the order of the optimizations (replacing the
// meta-optimizer).
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 39dd66fa1f..a3bc2f422e 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -122,6 +122,14 @@ message RegisterGraphRequest {
// Field(s) used by TensorFlow Debugger (tfdbg).
DebugOptions debug_options = 5;
+
+ // If graph_def contains any collective ops this must be a positive
+ // integer used to coordinate execution with other graphs. All
+ // graphs in a distributed execution with the same
+ // collective_graph_key will coordinate to use the same step_id
+ // concurrently so that BufRendezvous entries will make the correct
+ // values accessible.
+ int64 collective_graph_key = 7;
}
message RegisterGraphResponse {
diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc
index 8447028e38..42a4801dcb 100644
--- a/tensorflow/core/util/stat_summarizer.cc
+++ b/tensorflow/core/util/stat_summarizer.cc
@@ -31,26 +31,22 @@ limitations under the License.
namespace tensorflow {
+using Detail = StatsCalculator::Detail;
+
StatSummarizer::StatSummarizer(const StatSummarizerOptions& options)
- : options_(options) {}
+ : stats_calculator_(new StatsCalculator(options)) {}
StatSummarizer::StatSummarizer(const tensorflow::GraphDef& tensorflow_graph)
- : StatSummarizer(StatSummarizerOptions()) {}
+ : stats_calculator_(new StatsCalculator(StatSummarizerOptions())) {}
StatSummarizer::~StatSummarizer() {}
-void StatSummarizer::Reset() {
- run_total_us_.Reset();
- memory_.Reset();
- details_.clear();
-}
-
-void StatSummarizer::Validate(const Detail* detail,
+void StatSummarizer::Validate(const std::vector<TensorDescription>* outputs,
const NodeExecStats& ns) const {
- if (detail->outputs.size() != ns.output_size()) {
+ if (outputs->size() != ns.output_size()) {
LOG(WARNING) << "Number of outputs changed between runs for '"
- << ns.node_name() << "' - was " << detail->outputs.size()
- << ", now " << ns.output_size();
+ << ns.node_name() << "' - was " << outputs->size() << ", now "
+ << ns.output_size();
} else {
for (const auto& output : ns.output()) {
const int32 slot = output.slot();
@@ -58,7 +54,7 @@ void StatSummarizer::Validate(const Detail* detail,
// This is not a hard error for Switch ops, so just pass.
continue;
}
- const auto& stored = detail->outputs[slot];
+ const auto& stored = (*outputs)[slot];
const auto& current = output.tensor_description();
bool do_tensors_match =
@@ -129,6 +125,7 @@ void StatSummarizer::ProcessStepStats(const StepStats& step_stats) {
int64 first_node_start_us =
step_stats.dev_stats(0).node_stats(0).all_start_micros();
+ std::map<std::string, Detail> details;
int node_num = 0;
for (const auto& ds : step_stats.dev_stats()) {
@@ -172,7 +169,10 @@ void StatSummarizer::ProcessStepStats(const StepStats& step_stats) {
++node_num;
const int64 curr_time = ns.all_end_rel_micros();
curr_total_us += curr_time;
- auto result = details_.emplace(name, Detail());
+ auto result = details.emplace(name, Detail());
+ auto output_result =
+ outputs_.emplace(name, std::vector<TensorDescription>());
+ std::vector<TensorDescription>* outputs = &(output_result.first->second);
Detail* detail = &(result.first->second);
detail->start_us.UpdateStat(ns.all_start_micros() - first_node_start_us);
@@ -185,16 +185,15 @@ void StatSummarizer::ProcessStepStats(const StepStats& step_stats) {
detail->run_order = node_num;
- detail->outputs.resize(ns.output_size());
+ outputs->resize(ns.output_size());
for (const auto& output : ns.output()) {
const int32 slot = output.slot();
if ((slot < 0) || (slot >= ns.output_size())) {
// This is not a hard error for Switch ops, so just pass.
continue;
}
- detail->outputs[slot] = output.tensor_description();
+ (*outputs)[slot] = output.tensor_description();
}
-
detail->times_called = 0;
}
@@ -207,273 +206,22 @@ void StatSummarizer::ProcessStepStats(const StepStats& step_stats) {
mem_total += curr_node_mem;
++detail->times_called;
+ stats_calculator_->UpdateDetails(details);
- Validate(detail, ns);
- }
- }
-
- run_total_us_.UpdateStat(curr_total_us);
- memory_.UpdateStat(mem_total);
-}
-
-std::string StatSummarizer::ShortSummary() const {
- std::stringstream stream;
- stream << "Timings (microseconds): ";
- run_total_us_.OutputToStream(&stream);
- stream << std::endl;
-
- stream << "Memory (bytes): ";
- memory_.OutputToStream(&stream);
- stream << std::endl;
-
- stream << details_.size() << " nodes observed" << std::endl;
- return stream.str();
-}
-
-std::ostream& InitField(std::ostream& stream, int width) {
- stream << "\t" << std::right << std::setw(width) << std::fixed
- << std::setprecision(3);
- return stream;
-}
-
-std::string StatSummarizer::HeaderString(const string& title) const {
- std::stringstream stream;
-
- stream << "============================== " << title
- << " ==============================" << std::endl;
-
- InitField(stream, 24) << "[node type]";
- InitField(stream, 9) << "[start]";
- InitField(stream, 9) << "[first]";
- InitField(stream, 9) << "[avg ms]";
- InitField(stream, 8) << "[%]";
- InitField(stream, 8) << "[cdf%]";
- InitField(stream, 10) << "[mem KB]";
- InitField(stream, 9) << "[times called]";
- stream << "\t"
- << "[Name]";
- return stream.str();
-}
-
-std::string StatSummarizer::ColumnString(const Detail& detail,
- const int64 cumulative_stat_on_node,
- const Stat<int64>& stat) const {
- const double start_ms = detail.start_us.avg() / 1000.0;
- const double first_time_ms = detail.rel_end_us.first() / 1000.0;
- const double avg_time_ms = detail.rel_end_us.avg() / 1000.0;
- const double percentage = detail.rel_end_us.sum() * 100.0 / stat.sum();
- const double cdf_percentage = (cumulative_stat_on_node * 100.0f) / stat.sum();
- const int64 times_called = detail.times_called / num_runs();
-
- std::stringstream stream;
- InitField(stream, 24) << detail.type;
- InitField(stream, 9) << start_ms;
- InitField(stream, 9) << first_time_ms;
- InitField(stream, 9) << avg_time_ms;
- InitField(stream, 7) << percentage << "%";
- InitField(stream, 7) << cdf_percentage << "%";
- InitField(stream, 10) << detail.mem_used.newest() / 1000.0;
- InitField(stream, 9) << times_called;
- stream << "\t" << detail.name;
-
- return stream.str();
-}
-
-void StatSummarizer::OrderNodesByMetric(
- SortingMetric metric, std::vector<const Detail*>* details) const {
- std::priority_queue<std::pair<string, const Detail*>> sorted_list;
- const int num_nodes = details_.size();
-
- for (const auto& det : details_) {
- const Detail* detail = &(det.second);
- std::stringstream stream;
- stream << std::setw(20) << std::right << std::setprecision(10)
- << std::fixed;
-
- switch (metric) {
- case BY_NAME:
- stream << detail->name;
- break;
- case BY_RUN_ORDER:
- stream << num_nodes - detail->run_order;
- break;
- case BY_TIME:
- stream << detail->rel_end_us.avg();
- break;
- case BY_MEMORY:
- stream << detail->mem_used.avg();
- break;
- case BY_TYPE:
- stream << detail->type;
- break;
- default:
- stream << "";
- break;
+ Validate(outputs, ns);
}
-
- sorted_list.emplace(stream.str(), detail);
- }
-
- while (!sorted_list.empty()) {
- auto entry = sorted_list.top();
- sorted_list.pop();
- details->push_back(entry.second);
}
-}
-
-void StatSummarizer::ComputeStatsByType(
- std::map<string, int64>* node_type_map_count,
- std::map<string, int64>* node_type_map_time,
- std::map<string, int64>* node_type_map_memory,
- std::map<string, int64>* node_type_map_times_called,
- int64* accumulated_us) const {
- int64 run_count = run_total_us_.count();
-
- for (const auto& det : details_) {
- const string node_name = det.first;
- const Detail& detail = det.second;
-
- int64 curr_time_val =
- static_cast<int64>(detail.rel_end_us.sum() / run_count);
- *accumulated_us += curr_time_val;
- int64 curr_memory_val = detail.mem_used.newest();
-
- const string& node_type = detail.type;
-
- (*node_type_map_count)[node_type] += 1;
- (*node_type_map_time)[node_type] += curr_time_val;
- (*node_type_map_memory)[node_type] += curr_memory_val;
- (*node_type_map_times_called)[node_type] += detail.times_called / run_count;
- }
+ stats_calculator_->UpdateRunTotalUs(curr_total_us);
+ stats_calculator_->UpdateMemoryUsed(mem_total);
}
-std::string StatSummarizer::GetStatsByNodeType() const {
- std::stringstream stream;
-
- stream << "============================== Summary by node type "
- "=============================="
- << std::endl;
-
- LOG(INFO) << "Number of nodes executed: " << details_.size();
-
- std::map<string, int64> node_type_map_count;
- std::map<string, int64> node_type_map_time;
- std::map<string, int64> node_type_map_memory;
- std::map<string, int64> node_type_map_times_called;
- int64 accumulated_us = 0;
-
- ComputeStatsByType(&node_type_map_count, &node_type_map_time,
- &node_type_map_memory, &node_type_map_times_called,
- &accumulated_us);
-
- // Sort them.
- std::priority_queue<std::pair<int64, std::pair<string, int64>>> timings;
- for (const auto& node_type : node_type_map_time) {
- const int64 mem_used = node_type_map_memory[node_type.first];
- timings.emplace(node_type.second,
- std::pair<string, int64>(node_type.first, mem_used));
- }
-
- InitField(stream, 24) << "[Node type]";
- InitField(stream, 9) << "[count]";
- InitField(stream, 10) << "[avg ms]";
- InitField(stream, 11) << "[avg %]";
- InitField(stream, 11) << "[cdf %]";
- InitField(stream, 10) << "[mem KB]";
- InitField(stream, 10) << "[times called]";
- stream << std::endl;
-
- float cdf = 0.0f;
- while (!timings.empty()) {
- auto entry = timings.top();
- timings.pop();
-
- const string node_type = entry.second.first;
- const float memory = entry.second.second / 1000.0f;
-
- const int64 node_type_total_us = entry.first;
- const float time_per_run_ms = node_type_total_us / 1000.0f;
-
- const float percentage =
- ((entry.first / static_cast<float>(accumulated_us)) * 100.0f);
- cdf += percentage;
-
- InitField(stream, 24) << node_type;
- InitField(stream, 9) << node_type_map_count[node_type];
- InitField(stream, 10) << time_per_run_ms;
- InitField(stream, 10) << percentage << "%";
- InitField(stream, 10) << cdf << "%";
- InitField(stream, 10) << memory;
- InitField(stream, 9) << node_type_map_times_called[node_type];
- stream << std::endl;
- }
- stream << std::endl;
- return stream.str();
-}
-
-std::string StatSummarizer::GetStatsByMetric(const string& title,
- SortingMetric sorting_metric,
- int num_stats) const {
- std::vector<const Detail*> details;
- OrderNodesByMetric(sorting_metric, &details);
-
- double cumulative_stat_on_node = 0;
-
- std::stringstream stream;
- stream << HeaderString(title) << std::endl;
- int stat_num = 0;
- for (auto detail : details) {
- ++stat_num;
- if (num_stats > 0 && stat_num > num_stats) {
- break;
- }
-
- // TODO(andrewharp): Make this keep track of the particular metric for cdf.
- cumulative_stat_on_node += detail->rel_end_us.sum();
- stream << ColumnString(*detail, cumulative_stat_on_node, run_total_us_)
- << std::endl;
- }
- stream << std::endl;
- return stream.str();
-}
-
-std::string StatSummarizer::GetOutputString() const {
- std::stringstream stream;
- if (options_.show_run_order) {
- stream << GetStatsByMetric("Run Order", BY_RUN_ORDER,
- options_.run_order_limit);
- }
- if (options_.show_time) {
- stream << GetStatsByMetric("Top by Computation Time", BY_TIME,
- options_.time_limit);
- }
- if (options_.show_memory) {
- stream << GetStatsByMetric("Top by Memory Use", BY_MEMORY,
- options_.memory_limit);
- }
- if (options_.show_type) {
- stream << GetStatsByNodeType();
- }
- if (options_.show_summary) {
- stream << ShortSummary() << std::endl;
- }
- return stream.str();
-}
-
-void StatSummarizer::PrintStepStats() const {
- string output = GetOutputString();
- std::istringstream iss(output);
- for (std::string line; std::getline(iss, line);) {
- LOG(INFO) << line;
- }
-}
void StatSummarizer::PrintOutputs() const {
std::priority_queue<
std::pair<int64, const std::pair<const std::string, Detail>*>>
timings;
- for (const auto& entry : details_) {
+ for (const auto& entry : stats_calculator_->GetDetails()) {
timings.emplace(-entry.second.start_us.avg(), &entry);
}
@@ -481,10 +229,10 @@ void StatSummarizer::PrintOutputs() const {
while (!timings.empty()) {
auto entry = timings.top();
timings.pop();
- const Detail& detail = entry.second->second;
std::stringstream stream;
- stream << entry.second->first << "\t" << detail.outputs.size();
- for (const auto& tensor : detail.outputs) {
+ const auto detail_outputs = outputs_.at(entry.second->first);
+ stream << entry.second->first << "\t" << detail_outputs.size();
+ for (const auto& tensor : detail_outputs) {
stream << "\t" << DataTypeString(tensor.dtype());
stream << "\t" << tensor.shape().dim_size();
for (const auto& d : tensor.shape().dim()) {
diff --git a/tensorflow/core/util/stat_summarizer.h b/tensorflow/core/util/stat_summarizer.h
index 79fa63723e..173ed5cebc 100644
--- a/tensorflow/core/util/stat_summarizer.h
+++ b/tensorflow/core/util/stat_summarizer.h
@@ -13,20 +13,23 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_STAT_SUMMARIZER_H_
-#define TENSORFLOW_UTIL_STAT_SUMMARIZER_H_
+#ifndef TENSORFLOW_CORE_UTIL_STAT_SUMMARIZER_H_
+#define TENSORFLOW_CORE_UTIL_STAT_SUMMARIZER_H_
#include <stdlib.h>
#include <cmath>
#include <limits>
#include <map>
+#include <memory>
#include <sstream>
#include <string>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/stat_summarizer_options.h"
+#include "tensorflow/core/util/stats_calculator.h"
namespace tensorflow {
@@ -34,103 +37,6 @@ class GraphDef;
class StepStats;
class NodeExecStats;
-template <typename ValueType, typename HighPrecisionValueType = double>
-class Stat {
- public:
- void UpdateStat(ValueType v) {
- if (count_ == 0) {
- first_ = v;
- }
-
- newest_ = v;
- max_ = std::max(v, max_);
- min_ = std::min(v, min_);
- ++count_;
- sum_ += v;
- squared_sum_ += static_cast<HighPrecisionValueType>(v) * v;
- }
-
- void Reset() { new (this) Stat<ValueType, HighPrecisionValueType>(); }
-
- bool empty() const { return count_ == 0; }
-
- ValueType first() const { return first_; }
-
- ValueType newest() const { return newest_; }
-
- ValueType max() const { return max_; }
-
- ValueType min() const { return min_; }
-
- int64 count() const { return count_; }
-
- ValueType sum() const { return sum_; }
-
- HighPrecisionValueType squared_sum() const { return squared_sum_; }
-
- bool all_same() const { return (count_ == 0 || min_ == max_); }
-
- HighPrecisionValueType avg() const {
- return empty() ? std::numeric_limits<ValueType>::quiet_NaN()
- : static_cast<HighPrecisionValueType>(sum_) / count_;
- }
-
- ValueType std_deviation() const {
- return all_same() ? 0 : sqrt(squared_sum_ / count_ - avg() * avg());
- }
-
- void OutputToStream(std::ostream* stream) const {
- if (empty()) {
- *stream << "count=0";
- } else if (all_same()) {
- *stream << "count=" << count_ << " curr=" << newest_;
- if (count_ > 1) *stream << "(all same)";
- } else {
- *stream << "count=" << count_ << " first=" << first_
- << " curr=" << newest_ << " min=" << min_ << " max=" << max_
- << " avg=" << avg() << " std=" << std_deviation();
- }
- }
-
- friend std::ostream& operator<<(std::ostream& stream,
- const Stat<ValueType>& stat) {
- stat.OutputToStream(&stream);
- return stream;
- }
-
- private:
- ValueType first_ = 0;
- ValueType newest_ = 0;
- ValueType max_ = std::numeric_limits<ValueType>::min();
- ValueType min_ = std::numeric_limits<ValueType>::max();
- int64 count_ = 0;
- ValueType sum_ = 0;
- HighPrecisionValueType squared_sum_ = 0;
-};
-
-// Used to control the output of the statistics summarizer;
-class StatSummarizerOptions {
- public:
- StatSummarizerOptions()
- : show_run_order(true),
- run_order_limit(0),
- show_time(true),
- time_limit(10),
- show_memory(true),
- memory_limit(10),
- show_type(true),
- show_summary(true) {}
-
- bool show_run_order;
- int run_order_limit;
- bool show_time;
- int time_limit;
- bool show_memory;
- int memory_limit;
- bool show_type;
- bool show_summary;
-};
-
// A StatSummarizer assists in performance analysis of Graph executions.
//
// It summarizes time spent executing (on GPU/CPU), memory used etc. across
@@ -140,14 +46,6 @@ class StatSummarizerOptions {
// See tensorflow/tools/benchmark/benchmark_model.cc for an example usage.
class StatSummarizer {
public:
- enum SortingMetric {
- BY_NAME,
- BY_RUN_ORDER,
- BY_TIME,
- BY_MEMORY,
- BY_TYPE,
- };
-
explicit StatSummarizer(const StatSummarizerOptions& options);
// Deprecated: Use StatSummarizer(const StatSummarizerOptions&) instead. The
@@ -161,65 +59,58 @@ class StatSummarizer {
// Returns a string detailing the accumulated runtime stats in a tab-separated
// format which can be pasted into a spreadsheet for further analysis.
- std::string GetOutputString() const;
+ std::string GetOutputString() const {
+ return stats_calculator_->GetOutputString();
+ }
- std::string ShortSummary() const;
+ std::string ShortSummary() const {
+ return stats_calculator_->GetShortSummary();
+ }
// Prints the string returned by GetOutputString().
- void PrintStepStats() const;
+ void PrintStepStats() const { stats_calculator_->PrintStepStats(); }
// Prints the output tensor sizes and types for each node.
void PrintOutputs() const;
- void ComputeStatsByType(std::map<string, int64>* node_type_map_count,
- std::map<string, int64>* node_type_map_time,
- std::map<string, int64>* node_type_map_memory,
- std::map<string, int64>* node_type_map_times_called,
- int64* accumulated_us) const;
+ void ComputeStatsByType(
+ std::map<std::string, int64_t>* node_type_map_count,
+ std::map<std::string, int64_t>* node_type_map_time,
+ std::map<std::string, int64_t>* node_type_map_memory,
+ std::map<std::string, int64_t>* node_type_map_times_called,
+ int64_t* accumulated_us) const {
+ stats_calculator_->ComputeStatsByType(
+ node_type_map_count, node_type_map_time, node_type_map_memory,
+ node_type_map_times_called, accumulated_us);
+ }
- std::string GetStatsByNodeType() const;
+ std::string GetStatsByNodeType() const {
+ return stats_calculator_->GetStatsByNodeType();
+ }
std::string GetStatsByMetric(const string& title,
- SortingMetric sorting_metric,
- int num_stats) const;
-
- void Reset();
+ StatsCalculator::SortingMetric sorting_metric,
+ int num_stats) const {
+ return stats_calculator_->GetStatsByMetric(title, sorting_metric,
+ num_stats);
+ }
- // Returns number of runs.
- int num_runs() const { return static_cast<int>(run_total_us_.count()); }
+ int num_runs() const { return stats_calculator_->num_runs(); }
// Returns stats of total microseconds spent by all nodes in each run.
- const Stat<int64>& run_total_us() const { return run_total_us_; }
+ const Stat<int64_t>& run_total_us() const {
+ return stats_calculator_->run_total_us();
+ }
private:
- struct Detail {
- string name;
- string type;
- int64 run_order;
- Stat<int64> start_us;
- Stat<int64> rel_end_us;
- Stat<int64> mem_used;
- std::vector<TensorDescription> outputs;
- int64 times_called;
- };
-
- void Validate(const Detail* detail, const NodeExecStats& ns) const;
-
- void OrderNodesByMetric(SortingMetric sorting_metric,
- std::vector<const Detail*>* details) const;
-
- std::string HeaderString(const string& title) const;
- std::string ColumnString(const Detail& detail,
- const int64 cumulative_stat_on_node,
- const Stat<int64>& stat) const;
-
- Stat<int64> run_total_us_;
- Stat<int64> memory_;
-
- std::map<std::string, Detail> details_;
- StatSummarizerOptions options_;
+ void Validate(const std::vector<TensorDescription>* outputs,
+ const NodeExecStats& ns) const;
+
+ std::map<std::string, std::vector<TensorDescription> > outputs_;
+
+ std::unique_ptr<StatsCalculator> stats_calculator_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_STAT_SUMMARIZER_H_
+#endif // TENSORFLOW_CORE_UTIL_STAT_SUMMARIZER_H_
diff --git a/tensorflow/core/util/stat_summarizer_options.h b/tensorflow/core/util/stat_summarizer_options.h
new file mode 100644
index 0000000000..578020676b
--- /dev/null
+++ b/tensorflow/core/util/stat_summarizer_options.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_STAT_SUMMARIZER_OPTIONS_H_
+#define TENSORFLOW_CORE_UTIL_STAT_SUMMARIZER_OPTIONS_H_
+namespace tensorflow {
+// Used to control the output of the statistics summarizer;
+class StatSummarizerOptions {
+ public:
+ StatSummarizerOptions()
+ : show_run_order(true),
+ run_order_limit(0),
+ show_time(true),
+ time_limit(10),
+ show_memory(true),
+ memory_limit(10),
+ show_type(true),
+ show_summary(true) {}
+
+ bool show_run_order;
+ int run_order_limit;
+ bool show_time;
+ int time_limit;
+ bool show_memory;
+ int memory_limit;
+ bool show_type;
+ bool show_summary;
+};
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_STAT_SUMMARIZER_OPTIONS_H_
diff --git a/tensorflow/core/util/stats_calculator.cc b/tensorflow/core/util/stats_calculator.cc
new file mode 100644
index 0000000000..20353ec76e
--- /dev/null
+++ b/tensorflow/core/util/stats_calculator.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/stats_calculator.h"
+
+#include <iomanip>
+#include <map>
+#include <queue>
+#include <sstream>
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+StatsCalculator::StatsCalculator(const StatSummarizerOptions& options)
+ : options_(options) {}
+
+std::string StatsCalculator::GetShortSummary() const {
+ std::stringstream stream;
+ stream << "Timings (microseconds): ";
+ run_total_us_.OutputToStream(&stream);
+ stream << std::endl;
+
+ stream << "Memory (bytes): ";
+ memory_.OutputToStream(&stream);
+ stream << std::endl;
+
+ stream << details_.size() << " nodes observed" << std::endl;
+ return stream.str();
+}
+
+std::ostream& InitField(std::ostream& stream, int width) {
+ stream << "\t" << std::right << std::setw(width) << std::fixed
+ << std::setprecision(3);
+ return stream;
+}
+
+std::string StatsCalculator::HeaderString(const std::string& title) const {
+ std::stringstream stream;
+
+ stream << "============================== " << title
+ << " ==============================" << std::endl;
+
+ InitField(stream, 24) << "[node type]";
+ InitField(stream, 9) << "[start]";
+ InitField(stream, 9) << "[first]";
+ InitField(stream, 9) << "[avg ms]";
+ InitField(stream, 8) << "[%]";
+ InitField(stream, 8) << "[cdf%]";
+ InitField(stream, 10) << "[mem KB]";
+ InitField(stream, 9) << "[times called]";
+ stream << "\t"
+ << "[Name]";
+ return stream.str();
+}
+
+std::string StatsCalculator::ColumnString(const Detail& detail,
+ const int64_t cumulative_stat_on_node,
+ const Stat<int64_t>& stat) const {
+ const double start_ms = detail.start_us.avg() / 1000.0;
+ const double first_time_ms = detail.rel_end_us.first() / 1000.0;
+ const double avg_time_ms = detail.rel_end_us.avg() / 1000.0;
+ const double percentage = detail.rel_end_us.sum() * 100.0 / stat.sum();
+ const double cdf_percentage = (cumulative_stat_on_node * 100.0f) / stat.sum();
+ const int64_t times_called = detail.times_called / num_runs();
+
+ std::stringstream stream;
+ InitField(stream, 24) << detail.type;
+ InitField(stream, 9) << start_ms;
+ InitField(stream, 9) << first_time_ms;
+ InitField(stream, 9) << avg_time_ms;
+ InitField(stream, 7) << percentage << "%";
+ InitField(stream, 7) << cdf_percentage << "%";
+ InitField(stream, 10) << detail.mem_used.newest() / 1000.0;
+ InitField(stream, 9) << times_called;
+ stream << "\t" << detail.name;
+
+ return stream.str();
+}
+
+void StatsCalculator::OrderNodesByMetric(
+ SortingMetric metric, std::vector<const Detail*>* details) const {
+ std::priority_queue<std::pair<string, const Detail*>> sorted_list;
+ const int num_nodes = details_.size();
+
+ for (const auto& det : details_) {
+ const Detail* detail = &(det.second);
+ std::stringstream stream;
+ stream << std::setw(20) << std::right << std::setprecision(10)
+ << std::fixed;
+
+ switch (metric) {
+ case BY_NAME:
+ stream << detail->name;
+ break;
+ case BY_RUN_ORDER:
+ stream << num_nodes - detail->run_order;
+ break;
+ case BY_TIME:
+ stream << detail->rel_end_us.avg();
+ break;
+ case BY_MEMORY:
+ stream << detail->mem_used.avg();
+ break;
+ case BY_TYPE:
+ stream << detail->type;
+ break;
+ default:
+ stream << "";
+ break;
+ }
+
+ sorted_list.emplace(stream.str(), detail);
+ }
+
+ while (!sorted_list.empty()) {
+ auto entry = sorted_list.top();
+ sorted_list.pop();
+ details->push_back(entry.second);
+ }
+}
+
+void StatsCalculator::ComputeStatsByType(
+ std::map<std::string, int64_t>* node_type_map_count,
+ std::map<std::string, int64_t>* node_type_map_time,
+ std::map<std::string, int64_t>* node_type_map_memory,
+ std::map<std::string, int64_t>* node_type_map_times_called,
+ int64_t* accumulated_us) const {
+ int64_t run_count = run_total_us_.count();
+
+ for (const auto& det : details_) {
+ const string node_name = det.first;
+ const Detail& detail = det.second;
+
+ int64_t curr_time_val =
+ static_cast<int64_t>(detail.rel_end_us.sum() / run_count);
+ *accumulated_us += curr_time_val;
+
+ int64_t curr_memory_val = detail.mem_used.newest();
+
+ const string& node_type = detail.type;
+
+ (*node_type_map_count)[node_type] += 1;
+ (*node_type_map_time)[node_type] += curr_time_val;
+ (*node_type_map_memory)[node_type] += curr_memory_val;
+ (*node_type_map_times_called)[node_type] += detail.times_called / run_count;
+ }
+}
+
+std::string StatsCalculator::GetStatsByNodeType() const {
+ std::stringstream stream;
+
+ stream << "============================== Summary by node type "
+ "=============================="
+ << std::endl;
+
+ LOG(INFO) << "Number of nodes executed: " << details_.size();
+
+ std::map<std::string, int64_t> node_type_map_count;
+ std::map<std::string, int64_t> node_type_map_time;
+ std::map<std::string, int64_t> node_type_map_memory;
+ std::map<std::string, int64_t> node_type_map_times_called;
+ int64_t accumulated_us = 0;
+
+ ComputeStatsByType(&node_type_map_count, &node_type_map_time,
+ &node_type_map_memory, &node_type_map_times_called,
+ &accumulated_us);
+
+ // Sort them.
+ std::priority_queue<std::pair<int64_t, std::pair<string, int64_t>>> timings;
+ for (const auto& node_type : node_type_map_time) {
+ const int64_t mem_used = node_type_map_memory[node_type.first];
+ timings.emplace(node_type.second,
+ std::pair<string, int64_t>(node_type.first, mem_used));
+ }
+
+ InitField(stream, 24) << "[Node type]";
+ InitField(stream, 9) << "[count]";
+ InitField(stream, 10) << "[avg ms]";
+ InitField(stream, 11) << "[avg %]";
+ InitField(stream, 11) << "[cdf %]";
+ InitField(stream, 10) << "[mem KB]";
+ InitField(stream, 10) << "[times called]";
+ stream << std::endl;
+
+ float cdf = 0.0f;
+ while (!timings.empty()) {
+ auto entry = timings.top();
+ timings.pop();
+
+ const string node_type = entry.second.first;
+ const float memory = entry.second.second / 1000.0f;
+
+ const int64_t node_type_total_us = entry.first;
+ const float time_per_run_ms = node_type_total_us / 1000.0f;
+
+ const float percentage =
+ ((entry.first / static_cast<float>(accumulated_us)) * 100.0f);
+ cdf += percentage;
+
+ InitField(stream, 24) << node_type;
+ InitField(stream, 9) << node_type_map_count[node_type];
+ InitField(stream, 10) << time_per_run_ms;
+ InitField(stream, 10) << percentage << "%";
+ InitField(stream, 10) << cdf << "%";
+ InitField(stream, 10) << memory;
+ InitField(stream, 9) << node_type_map_times_called[node_type];
+ stream << std::endl;
+ }
+ stream << std::endl;
+ return stream.str();
+}
+
+std::string StatsCalculator::GetStatsByMetric(const std::string& title,
+ SortingMetric sorting_metric,
+ int num_stats) const {
+ std::vector<const Detail*> details;
+ OrderNodesByMetric(sorting_metric, &details);
+
+ double cumulative_stat_on_node = 0;
+
+ std::stringstream stream;
+ stream << HeaderString(title) << std::endl;
+ int stat_num = 0;
+ for (auto detail : details) {
+ ++stat_num;
+ if (num_stats > 0 && stat_num > num_stats) {
+ break;
+ }
+
+ // TODO(andrewharp): Make this keep track of the particular metric for cdf.
+ cumulative_stat_on_node += detail->rel_end_us.sum();
+ stream << ColumnString(*detail, cumulative_stat_on_node, run_total_us_)
+ << std::endl;
+ }
+ stream << std::endl;
+ return stream.str();
+}
+
+std::string StatsCalculator::GetOutputString() const {
+ std::stringstream stream;
+ if (options_.show_run_order) {
+ stream << GetStatsByMetric("Run Order", BY_RUN_ORDER,
+ options_.run_order_limit);
+ }
+ if (options_.show_time) {
+ stream << GetStatsByMetric("Top by Computation Time", BY_TIME,
+ options_.time_limit);
+ }
+ if (options_.show_memory) {
+ stream << GetStatsByMetric("Top by Memory Use", BY_MEMORY,
+ options_.memory_limit);
+ }
+ if (options_.show_type) {
+ stream << GetStatsByNodeType();
+ }
+ if (options_.show_summary) {
+ stream << GetShortSummary() << std::endl;
+ }
+ return stream.str();
+}
+
+void StatsCalculator::PrintStepStats() const {
+ string output = GetOutputString();
+ std::istringstream iss(output);
+ for (std::string line; std::getline(iss, line);) {
+ LOG(INFO) << line;
+ }
+}
+
+void StatsCalculator::UpdateDetails(
+ const std::map<std::string, Detail>& details) {
+ details_.insert(details.begin(), details.end());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/stats_calculator.h b/tensorflow/core/util/stats_calculator.h
new file mode 100644
index 0000000000..a1033465fb
--- /dev/null
+++ b/tensorflow/core/util/stats_calculator.h
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_STATS_CALCULATOR_H_
+#define TENSORFLOW_CORE_UTIL_STATS_CALCULATOR_H_
+
+#include <stdlib.h>
+
+#include <cmath>
+#include <limits>
+#include <map>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/util/stat_summarizer_options.h"
+
+namespace tensorflow {
+
+template <typename ValueType, typename HighPrecisionValueType = double>
+class Stat {
+ public:
+ void UpdateStat(ValueType v) {
+ if (count_ == 0) {
+ first_ = v;
+ }
+
+ newest_ = v;
+ max_ = std::max(v, max_);
+ min_ = std::min(v, min_);
+ ++count_;
+ sum_ += v;
+ squared_sum_ += static_cast<HighPrecisionValueType>(v) * v;
+ }
+
+ void Reset() { new (this) Stat<ValueType, HighPrecisionValueType>(); }
+
+ bool empty() const { return count_ == 0; }
+
+ ValueType first() const { return first_; }
+
+ ValueType newest() const { return newest_; }
+
+ ValueType max() const { return max_; }
+
+ ValueType min() const { return min_; }
+
+ int64_t count() const { return count_; }
+
+ ValueType sum() const { return sum_; }
+
+ HighPrecisionValueType squared_sum() const { return squared_sum_; }
+
+ bool all_same() const { return (count_ == 0 || min_ == max_); }
+
+ HighPrecisionValueType avg() const {
+ return empty() ? std::numeric_limits<ValueType>::quiet_NaN()
+ : static_cast<HighPrecisionValueType>(sum_) / count_;
+ }
+
+ ValueType std_deviation() const {
+ return all_same() ? 0 : sqrt(squared_sum_ / count_ - avg() * avg());
+ }
+
+ void OutputToStream(std::ostream* stream) const {
+ if (empty()) {
+ *stream << "count=0";
+ } else if (all_same()) {
+ *stream << "count=" << count_ << " curr=" << newest_;
+ if (count_ > 1) *stream << "(all same)";
+ } else {
+ *stream << "count=" << count_ << " first=" << first_
+ << " curr=" << newest_ << " min=" << min_ << " max=" << max_
+ << " avg=" << avg() << " std=" << std_deviation();
+ }
+ }
+
+ friend std::ostream& operator<<(std::ostream& stream,
+ const Stat<ValueType>& stat) {
+ stat.OutputToStream(&stream);
+ return stream;
+ }
+
+ private:
+ ValueType first_ = 0;
+ ValueType newest_ = 0;
+ ValueType max_ = std::numeric_limits<ValueType>::min();
+ ValueType min_ = std::numeric_limits<ValueType>::max();
+ int64_t count_ = 0;
+ ValueType sum_ = 0;
+ HighPrecisionValueType squared_sum_ = 0;
+};
+
+// A StatsCalculator assists in performance analysis of Graph executions.
+//
+// It summarizes time spent executing (on GPU/CPU), memory used etc for
+// graph execution.
+//
+// For example usage see StatsSummarizer.
+class StatsCalculator {
+ public:
+ enum SortingMetric {
+ BY_NAME,
+ BY_RUN_ORDER,
+ BY_TIME,
+ BY_MEMORY,
+ BY_TYPE,
+ };
+
+ explicit StatsCalculator(const StatSummarizerOptions& options);
+
+ // Returns a string detailing the accumulated runtime stats in a tab-separated
+ // format which can be pasted into a spreadsheet for further analysis.
+ std::string GetOutputString() const;
+
+ std::string GetShortSummary() const;
+
+ // Prints the string returned by GetOutputString().
+ void PrintStepStats() const;
+
+ void ComputeStatsByType(
+ std::map<std::string, int64_t>* node_type_map_count,
+ std::map<std::string, int64_t>* node_type_map_time,
+ std::map<std::string, int64_t>* node_type_map_memory,
+ std::map<std::string, int64_t>* node_type_map_times_called,
+ int64_t* accumulated_us) const;
+
+ std::string GetStatsByNodeType() const;
+
+ std::string GetStatsByMetric(const std::string& title,
+ SortingMetric sorting_metric,
+ int num_stats) const;
+
+ // Returns number of runs.
+ int num_runs() const { return static_cast<int>(run_total_us_.count()); }
+
+ // Returns stats of total microseconds spent by all nodes in each run.
+ const Stat<int64_t>& run_total_us() const { return run_total_us_; }
+
+ void UpdateRunTotalUs(int64_t run_total_us) {
+ run_total_us_.UpdateStat(run_total_us);
+ }
+
+ void UpdateMemoryUsed(int64_t memory) { memory_.UpdateStat(memory); }
+
+ struct Detail {
+ std::string name;
+ std::string type;
+ int64_t run_order;
+ Stat<int64_t> start_us;
+ Stat<int64_t> rel_end_us;
+ Stat<int64_t> mem_used;
+ int64_t times_called;
+ };
+
+ const std::map<std::string, Detail>& GetDetails() const { return details_; }
+ void UpdateDetails(const std::map<std::string, Detail>& details);
+
+ private:
+ void OrderNodesByMetric(SortingMetric sorting_metric,
+ std::vector<const Detail*>* details) const;
+
+ std::string HeaderString(const std::string& title) const;
+ std::string ColumnString(const Detail& detail,
+ const int64_t cumulative_stat_on_node,
+ const Stat<int64_t>& stat) const;
+
+ Stat<int64_t> run_total_us_;
+ Stat<int64_t> memory_;
+
+ std::map<std::string, Detail> details_;
+ StatSummarizerOptions options_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_STATS_CALCULATOR_H_
diff --git a/tensorflow/docs_src/community/security.md b/tensorflow/docs_src/community/security.md
deleted file mode 100644
index 8d13c7a1ea..0000000000
--- a/tensorflow/docs_src/community/security.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Using TensorFlow Securely
-
-Before using TensorFlow, please take a look at our security model, list of
-recent security announcements, and ways you can report security issues to the
-TensorFlow team at the
-[https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](Using
-TensorFlow Securely) page on GitHub.
diff --git a/tensorflow/docs_src/get_started/datasets_quickstart.md b/tensorflow/docs_src/get_started/datasets_quickstart.md
index c972e5e555..020e40dd3b 100644
--- a/tensorflow/docs_src/get_started/datasets_quickstart.md
+++ b/tensorflow/docs_src/get_started/datasets_quickstart.md
@@ -14,7 +14,7 @@ introduces the API by walking through two simple examples:
Taking slices from an array is the simplest way to get started with `tf.data`.
-The @{$get_started/premade_estimators$Premade Estimators} chapter describes
+The @{$premade_estimators$Premade Estimators} chapter describes
the following `train_input_fn`, from
[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py),
to pipe the data into the Estimator:
@@ -377,7 +377,7 @@ Now you have the basic idea of how to efficiently load data into an
Estimator. Consider the following documents next:
-* @{$get_started/custom_estimators}, which demonstrates how to build your own
+* @{$custom_estimators}, which demonstrates how to build your own
custom `Estimator` model.
* The @{$low_level_intro#datasets$Low Level Introduction}, which demonstrates
how to experiment directly with `tf.data.Datasets` using TensorFlow's low
diff --git a/tensorflow/docs_src/get_started/get_started_for_beginners.md b/tensorflow/docs_src/get_started/get_started_for_beginners.md
deleted file mode 100644
index d5a80e22c5..0000000000
--- a/tensorflow/docs_src/get_started/get_started_for_beginners.md
+++ /dev/null
@@ -1,751 +0,0 @@
-# Get Started with Graph Execution
-
-This document explains how to use machine learning to classify (categorize)
-Iris flowers by species. This document dives deeply into the TensorFlow
-code to do exactly that, explaining ML fundamentals along the way.
-
-If the following list describes you, then you are in the right place:
-
-* You know little to nothing about machine learning.
-* You want to learn how to write TensorFlow programs.
-* You can code (at least a little) in Python.
-
-If you are already familiar with basic machine learning concepts
-but are new to TensorFlow, read
-@{$premade_estimators$Getting Started with TensorFlow: for ML Experts}.
-
-If you'd like to learn a lot about the basics of Machine Learning,
-consider taking
-[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/).
-
-
-## The Iris classification problem
-
-Imagine you are a botanist seeking an automated way to classify each
-Iris flower you find. Machine learning provides many ways to classify flowers.
-For instance, a sophisticated machine learning program could classify flowers
-based on photographs. Our ambitions are more modest--we're going to classify
-Iris flowers based solely on the length and width of their
-[sepals](https://en.wikipedia.org/wiki/Sepal) and
-[petals](https://en.wikipedia.org/wiki/Petal).
-
-The Iris genus entails about 300 species, but our program will classify only
-the following three:
-
-* Iris setosa
-* Iris virginica
-* Iris versicolor
-
-<div style="margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%"
- alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor"
- src="../images/iris_three_species.jpg">
-</div>
-
-**From left to right,
-[*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by
-[Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0),
-[*Iris versicolor*](https://commons.wikimedia.org/w/index.php?curid=248095) (by
-[Dlanglois](https://commons.wikimedia.org/wiki/User:Dlanglois), CC BY-SA 3.0),
-and [*Iris virginica*](https://www.flickr.com/photos/33397993@N05/3352169862)
-(by [Frank Mayfield](https://www.flickr.com/photos/33397993@N05), CC BY-SA
-2.0).**
-<p>&nbsp;</p>
-
-Fortunately, someone has already created [a data set of 120 Iris
-flowers](https://en.wikipedia.org/wiki/Iris_flower_data_set)
-with the sepal and petal measurements. This data set has become
-one of the canonical introductions to machine learning classification problems.
-(The [MNIST database](https://en.wikipedia.org/wiki/MNIST_database),
-which contains handwritten digits, is another popular classification
-problem.) The first 5 entries of the Iris data set
-look as follows:
-
-| Sepal length | sepal width | petal length | petal width | species
-| --- | --- | --- | --- | ---
-|6.4 | 2.8 | 5.6 | 2.2 | 2
-|5.0 | 2.3 | 3.3 | 1.0 | 1
-|4.9 | 2.5 | 4.5 | 1.7 | 2
-|4.9 | 3.1 | 1.5 | 0.1 | 0
-|5.7 | 3.8 | 1.7 | 0.3 | 0
-
-Let's introduce some terms:
-
-* The last column (species) is called the
- [**label**](https://developers.google.com/machine-learning/glossary/#label);
- the first four columns are called
- [**features**](https://developers.google.com/machine-learning/glossary/#feature).
- Features are characteristics of an example, while the label is
- the thing we're trying to predict.
-
-* An [**example**](https://developers.google.com/machine-learning/glossary/#example)
- consists of the set of features and the label for one sample
- flower. The preceding table shows 5 examples from a data set of
- 120 examples.
-
-Each label is naturally a string (for example, "setosa"), but machine learning
-typically relies on numeric values. Therefore, someone mapped each string to
-a number. Here's the representation scheme:
-
-* 0 represents setosa
-* 1 represents versicolor
-* 2 represents virginica
-
-For a look at other examples of labels and examples, see the
-[ML Terminology section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/framing/ml-terminology).
-
-
-## Models and training
-
-A **model** is the relationship between features
-and the label. For the Iris problem, the model defines the relationship
-between the sepal and petal measurements and the predicted Iris species. Some
-simple models can be described with a few lines of algebra, but complex machine
-learning models have a large number of parameters that are difficult to
-summarize.
-
-Could you determine the relationship between the four features and the
-Iris species *without* using machine learning? That is, could you use
-traditional programming techniques (for example, a lot of conditional
-statements) to create a model? Maybe. You could play with the data set
-long enough to determine the right relationships of petal and sepal
-measurements to particular species. However, a good machine learning
-approach *determines the model for you*. That is, if you feed enough
-representative examples into the right machine learning model type, the program
-will determine the relationship between sepals, petals, and species.
-
-**Training** is the stage of machine learning in which the model is
-gradually optimized (learned). The Iris problem is an example
-of [**supervised machine
-learning**](https://developers.google.com/machine-learning/glossary/#supervised_machine_learning)
-in which a model is trained from examples that contain labels. (In
-[**unsupervised machine
-learning**](https://developers.google.com/machine-learning/glossary/#unsupervised_machine_learning),
-the examples don't contain labels. Instead, the model typically finds
-patterns among the features.)
-
-
-
-
-## Get the sample program
-
-Prior to playing with the sample code in this document, do the following:
-
-1. @{$install$Install TensorFlow}.
-2. If you installed TensorFlow with virtualenv or Anaconda, activate your
- TensorFlow environment.
-3. Install or upgrade pandas by issuing the following command:
-
- `pip install pandas`
-
-
-Take the following steps to get the sample program:
-
-1. Clone the TensorFlow Models repository from github by entering the following
- command:
-
- `git clone https://github.com/tensorflow/models`
-
-2. Change directory within that branch to the location containing the examples
- used in this document:
-
- `cd models/samples/core/get_started/`
-
-In that `get_started` directory, you'll find a program
-named `premade_estimator.py`.
-
-
-## Run the sample program
-
-You run TensorFlow programs as you would run any Python program. Therefore,
-issue the following command from a command line to
-run `premade_estimators.py`:
-
-``` bash
-python premade_estimator.py
-```
-
-Running the program should output a whole bunch of information ending with
-three prediction lines like the following:
-
-```None
-...
-Prediction is "Setosa" (99.6%), expected "Setosa"
-
-Prediction is "Versicolor" (99.8%), expected "Versicolor"
-
-Prediction is "Virginica" (97.9%), expected "Virginica"
-```
-
-If the program generates errors instead of predictions, ask yourself the
-following questions:
-
-* Did you install TensorFlow properly?
-* Are you using the correct version of TensorFlow? The `premade_estimators.py`
- program requires at least TensorFlow v1.4.
-* If you installed TensorFlow with virtualenv or Anaconda, did you activate
- the environment?
-
-
-
-## The TensorFlow programming stack
-
-As the following illustration shows, TensorFlow
-provides a programming stack consisting of multiple API layers:
-
-<div style="margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/tensorflow_programming_environment.png">
-</div>
-
-**The TensorFlow Programming Environment.**
-<p>&nbsp;</p>
-
-As you start writing TensorFlow programs, we strongly recommend focusing on
-the following two high-level APIs:
-
-* Estimators
-* Datasets
-
-Although we'll grab an occasional convenience function from other APIs,
-this document focuses on the preceding two APIs.
-
-
-## The program itself
-
-Thanks for your patience; let's dig into the code.
-The general outline of `premade_estimator.py`--and many other TensorFlow
-programs--is as follows:
-
-* Import and parse the data sets.
-* Create feature columns to describe the data.
-* Select the type of model
-* Train the model.
-* Evaluate the model's effectiveness.
-* Let the trained model make predictions.
-
-The following subsections detail each part.
-
-
-### Import and parse the data sets
-
-The Iris program requires the data from the following two .csv files:
-
-* `http://download.tensorflow.org/data/iris_training.csv`, which contains
- the training set.
-* `http://download.tensorflow.org/data/iris_test.csv`, which contains the
- test set.
-
-The **training set** contains the examples that we'll use to train the model;
-the **test set** contains the examples that we'll use to evaluate the trained
-model's effectiveness.
-
-The training set and test set started out as a
-single data set. Then, someone split the examples, with the majority going into
-the training set and the remainder going into the test set. Adding
-examples to the training set usually builds a better model; however, adding
-more examples to the test set enables us to better gauge the model's
-effectiveness. Regardless of the split, the examples in the test set
-must be separate from the examples in the training set. Otherwise, you can't
-accurately determine the model's effectiveness.
-
-The `premade_estimators.py` program relies on the `load_data` function
-in the adjacent [`iris_data.py`](
-https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py)
-file to read in and parse the training set and test set.
-Here is a heavily commented version of the function:
-
-```python
-TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
-TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
-
-CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
- 'PetalLength', 'PetalWidth', 'Species']
-
-...
-
-def load_data(label_name='Species'):
- """Parses the csv file in TRAIN_URL and TEST_URL."""
-
- # Create a local copy of the training set.
- train_path = tf.keras.utils.get_file(fname=TRAIN_URL.split('/')[-1],
- origin=TRAIN_URL)
- # train_path now holds the pathname: ~/.keras/datasets/iris_training.csv
-
- # Parse the local CSV file.
- train = pd.read_csv(filepath_or_buffer=train_path,
- names=CSV_COLUMN_NAMES, # list of column names
- header=0 # ignore the first row of the CSV file.
- )
- # train now holds a pandas DataFrame, which is data structure
- # analogous to a table.
-
- # 1. Assign the DataFrame's labels (the right-most column) to train_label.
- # 2. Delete (pop) the labels from the DataFrame.
- # 3. Assign the remainder of the DataFrame to train_features
- train_features, train_label = train, train.pop(label_name)
-
- # Apply the preceding logic to the test set.
- test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)
- test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
- test_features, test_label = test, test.pop(label_name)
-
- # Return four DataFrames.
- return (train_features, train_label), (test_features, test_label)
-```
-
-Keras is an open-sourced machine learning library; `tf.keras` is a TensorFlow
-implementation of Keras. The `premade_estimator.py` program only accesses
-one `tf.keras` function; namely, the `tf.keras.utils.get_file` convenience
-function, which copies a remote CSV file to a local file system.
-
-The call to `load_data` returns two `(feature,label)` pairs, for the training
-and test sets respectively:
-
-```python
- # Call load_data() to parse the CSV file.
- (train_feature, train_label), (test_feature, test_label) = load_data()
-```
-
-Pandas is an open-source Python library leveraged by several
-TensorFlow functions. A pandas
-[**DataFrame**](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html)
-is a table with named columns headers and numbered rows.
-The features returned by `load_data` are packed in `DataFrames`.
-For example, the `test_feature` DataFrame looks as follows:
-
-```none
- SepalLength SepalWidth PetalLength PetalWidth
-0 5.9 3.0 4.2 1.5
-1 6.9 3.1 5.4 2.1
-2 5.1 3.3 1.7 0.5
-...
-27 6.7 3.1 4.7 1.5
-28 6.7 3.3 5.7 2.5
-29 6.4 2.9 4.3 1.3
-```
-
-
-### Describe the data
-
-A **feature column** is a data structure that tells your model
-how to interpret the data in each feature. In the Iris problem,
-we want the model to interpret the data in each
-feature as its literal floating-point value; that is, we want the
-model to interpret an input value like 5.4 as, well, 5.4. However,
-in other machine learning problems, it is often desirable to interpret
-data less literally. Using feature columns to
-interpret data is such a rich topic that we devote an entire
-@{$feature_columns$document} to it.
-
-From a code perspective, you build a list of `feature_column` objects by calling
-functions from the @{tf.feature_column} module. Each object describes an input
-to the model. To tell the model to interpret data as a floating-point value,
-call @{tf.feature_column.numeric_column}. In `premade_estimator.py`, all
-four features should be interpreted as literal floating-point values, so
-the code to create a feature column looks as follows:
-
-```python
-# Create feature columns for all features.
-my_feature_columns = []
-for key in train_x.keys():
- my_feature_columns.append(tf.feature_column.numeric_column(key=key))
-```
-
-Here is a less elegant, but possibly clearer, alternative way to
-encode the preceding block:
-
-```python
-my_feature_columns = [
- tf.feature_column.numeric_column(key='SepalLength'),
- tf.feature_column.numeric_column(key='SepalWidth'),
- tf.feature_column.numeric_column(key='PetalLength'),
- tf.feature_column.numeric_column(key='PetalWidth')
-]
-```
-
-
-### Select the type of model
-
-We need to select the kind of model that will be trained.
-Lots of model types exist; picking the ideal type takes experience.
-We've selected a neural network to solve the Iris problem. [**Neural
-networks**](https://developers.google.com/machine-learning/glossary/#neural_network)
-can find complex relationships between features and the label.
-A neural network is a highly-structured graph, organized into one or more
-[**hidden layers**](https://developers.google.com/machine-learning/glossary/#hidden_layer).
-Each hidden layer consists of one or more
-[**neurons**](https://developers.google.com/machine-learning/glossary/#neuron).
-There are several categories of neural networks.
-We'll be using a [**fully connected neural
-network**](https://developers.google.com/machine-learning/glossary/#fully_connected_layer),
-which means that the neurons in one layer take inputs from *every* neuron in
-the previous layer. For example, the following figure illustrates a
-fully connected neural network consisting of three hidden layers:
-
-* The first hidden layer contains four neurons.
-* The second hidden layer contains three neurons.
-* The third hidden layer contains two neurons.
-
-<div style="margin:auto; margin-bottom:10px; margin-top:20px;">
-<img style="width:100%" src="../images/simple_dnn.svg">
-</div>
-
-**A neural network with three hidden layers.**
-<p>&nbsp;</p>
-
-For a more detailed introduction to neural networks, see the
-[Introduction to Neural Nets section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/introduction-to-neural-networks/anatomy).
-
-To specify a model type, instantiate an
-[**Estimator**](https://developers.google.com/machine-learning/glossary/#Estimators)
-class. TensorFlow provides two categories of Estimators:
-
-* [**pre-made
- Estimators**](https://developers.google.com/machine-learning/glossary/#pre-made_Estimator),
- which someone else has already written for you.
-* [**custom
- Estimators**](https://developers.google.com/machine-learning/glossary/#custom_estimator),
- which you must code yourself, at least partially.
-
-To implement a neural network, the `premade_estimators.py` program uses
-a pre-made Estimator named @{tf.estimator.DNNClassifier}. This Estimator
-builds a neural network that classifies examples. The following call
-instantiates `DNNClassifier`:
-
-```python
- classifier = tf.estimator.DNNClassifier(
- feature_columns=my_feature_columns,
- hidden_units=[10, 10],
- n_classes=3)
-```
-
-Use the `hidden_units` parameter to define the number of neurons
-in each hidden layer of the neural network. Assign this parameter
-a list. For example:
-
-```python
- hidden_units=[10, 10],
-```
-
-The length of the list assigned to `hidden_units` identifies the number of
-hidden layers (2, in this case).
-Each value in the list represents the number of neurons in a particular
-hidden layer (10 in the first hidden layer and 10 in the second hidden layer).
-To change the number of hidden layers or neurons, simply assign a different
-list to the `hidden_units` parameter.
-
-The ideal number of hidden layers and neurons depends on the problem
-and the data set. Like many aspects of machine learning,
-picking the ideal shape of the neural network requires some mixture
-of knowledge and experimentation.
-As a rule of thumb, increasing the number of hidden layers and neurons
-*typically* creates a more powerful model, which requires more data to
-train effectively.
-
-The `n_classes` parameter specifies the number of possible values that the
-neural network can predict. Since the Iris problem classifies 3 Iris species,
-we set `n_classes` to 3.
-
-The constructor for `tf.Estimator.DNNClassifier` takes an optional argument
-named `optimizer`, which our sample code chose not to specify. The
-[**optimizer**](https://developers.google.com/machine-learning/glossary/#optimizer)
-controls how the model will train. As you develop more expertise in machine
-learning, optimizers and
-[**learning
-rate**](https://developers.google.com/machine-learning/glossary/#learning_rate)
-will become very important.
-
-
-
-### Train the model
-
-Instantiating a `tf.Estimator.DNNClassifier` creates a framework for learning
-the model. Basically, we've wired a network but haven't yet let data flow
-through it. To train the neural network, call the Estimator object's `train`
-method. For example:
-
-```python
- classifier.train(
- input_fn=lambda:train_input_fn(train_feature, train_label, args.batch_size),
- steps=args.train_steps)
-```
-
-The `steps` argument tells `train` to stop training after the specified
-number of iterations. Increasing `steps` increases the amount of time
-the model will train. Counter-intuitively, training a model longer
-does not guarantee a better model. The default value of `args.train_steps`
-is 1000. The number of steps to train is a
-[**hyperparameter**](https://developers.google.com/machine-learning/glossary/#hyperparameter)
-you can tune. Choosing the right number of steps usually
-requires both experience and experimentation.
-
-The `input_fn` parameter identifies the function that supplies the
-training data. The call to the `train` method indicates that the
-`train_input_fn` function will supply the training data. Here's that
-method's signature:
-
-```python
-def train_input_fn(features, labels, batch_size):
-```
-
-We're passing the following arguments to `train_input_fn`:
-
-* `train_feature` is a Python dictionary in which:
- * Each key is the name of a feature.
- * Each value is an array containing the values for each example in the
- training set.
-* `train_label` is an array containing the values of the label for every
- example in the training set.
-* `args.batch_size` is an integer defining the [**batch
- size**](https://developers.google.com/machine-learning/glossary/#batch_size).
-
-The `train_input_fn` function relies on the **Dataset API**. This is a
-high-level TensorFlow API for reading data and transforming it into a form
-that the `train` method requires. The following call converts the
-input features and labels into a `tf.data.Dataset` object, which is the base
-class of the Dataset API:
-
-```python
- dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
-```
-
-The `tf.dataset` class provides many useful functions for preparing examples
-for training. The following line calls three of those functions:
-
-```python
- dataset = dataset.shuffle(buffer_size=1000).repeat(count=None).batch(batch_size)
-```
-
-Training works best if the training examples are in
-random order. To randomize the examples, call
-`tf.data.Dataset.shuffle`. Setting the `buffer_size` to a value
-larger than the number of examples (120) ensures that the data will
-be well shuffled.
-
-During training, the `train` method typically processes the
-examples multiple times. Calling the
-`tf.data.Dataset.repeat` method without any arguments ensures
-that the `train` method has an infinite supply of (now shuffled)
-training set examples.
-
-The `train` method processes a
-[**batch**](https://developers.google.com/machine-learning/glossary/#batch)
-of examples at a time.
-The `tf.data.Dataset.batch` method creates a batch by
-concatenating multiple examples.
-This program sets the default [**batch
-size**](https://developers.google.com/machine-learning/glossary/#batch_size)
-to 100, meaning that the `batch` method will concatenate groups of
-100 examples. The ideal batch size depends on the problem. As a rule
-of thumb, smaller batch sizes usually enable the `train` method to train
-the model faster at the expense (sometimes) of accuracy.
-
-The following `return` statement passes a batch of examples back to
-the caller (the `train` method).
-
-```python
- return dataset.make_one_shot_iterator().get_next()
-```
-
-
-### Evaluate the model
-
-**Evaluating** means determining how effectively the model makes
-predictions. To determine the Iris classification model's effectiveness,
-pass some sepal and petal measurements to the model and ask the model
-to predict what Iris species they represent. Then compare the model's
-prediction against the actual label. For example, a model that picked
-the correct species on half the input examples would have an
-[accuracy](https://developers.google.com/machine-learning/glossary/#accuracy)
-of 0.5. The following suggests a more effective model:
-
-
-<table>
- <tr>
- <th style="background-color:darkblue" colspan="5">
- Test Set</th>
- </tr>
- <tr>
- <th colspan="4">Features</th>
- <th colspan="1">Label</th>
- <th colspan="1">Prediction</th>
- </tr>
- <tr> <td>5.9</td> <td>3.0</td> <td>4.3</td> <td>1.5</td> <td>1</td>
- <td style="background-color:green">1</td></tr>
- <tr> <td>6.9</td> <td>3.1</td> <td>5.4</td> <td>2.1</td> <td>2</td>
- <td style="background-color:green">2</td></tr>
- <tr> <td>5.1</td> <td>3.3</td> <td>1.7</td> <td>0.5</td> <td>0</td>
- <td style="background-color:green">0</td></tr>
- <tr> <td>6.0</td> <td>3.4</td> <td>4.5</td> <td>1.6</td> <td>1</td>
- <td style="background-color:red">2</td></tr>
- <tr> <td>5.5</td> <td>2.5</td> <td>4.0</td> <td>1.3</td> <td>1</td>
- <td style="background-color:green">1</td></tr>
-</table>
-
-**A model that is 80% accurate.**
-<p>&nbsp;</p>
-
-To evaluate a model's effectiveness, each Estimator provides an `evaluate`
-method. The `premade_estimator.py` program calls `evaluate` as follows:
-
-```python
-# Evaluate the model.
-eval_result = classifier.evaluate(
- input_fn=lambda:eval_input_fn(test_x, test_y, args.batch_size))
-
-print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
-```
-
-The call to `classifier.evaluate` is similar to the call to `classifier.train`.
-The biggest difference is that `classifier.evaluate` must get its examples
-from the test set rather than the training set. In other words, to
-fairly assess a model's effectiveness, the examples used to
-*evaluate* a model must be different from the examples used to *train*
-the model. The `eval_input_fn` function serves a batch of examples from
-the test set. Here's the `eval_input_fn` method:
-
-```python
-def eval_input_fn(features, labels=None, batch_size=None):
- """An input function for evaluation or prediction"""
- if labels is None:
- # No labels, use only features.
- inputs = features
- else:
- inputs = (features, labels)
-
- # Convert inputs to a tf.dataset object.
- dataset = tf.data.Dataset.from_tensor_slices(inputs)
-
- # Batch the examples
- assert batch_size is not None, "batch_size must not be None"
- dataset = dataset.batch(batch_size)
-
- # Return the read end of the pipeline.
- return dataset.make_one_shot_iterator().get_next()
-```
-
-In brief, `eval_input_fn` does the following when called by
-`classifier.evaluate`:
-
-1. Converts the features and labels from the test set to a `tf.dataset`
- object.
-2. Creates a batch of test set examples. (There's no need to shuffle
- or repeat the test set examples.)
-3. Returns that batch of test set examples to `classifier.evaluate`.
-
-Running this code yields the following output (or something close to it):
-
-```none
-Test set accuracy: 0.967
-```
-
-An accuracy of 0.967 implies that our trained model correctly classified 29
-out of the 30 Iris species in the test set.
-
-To get a deeper understanding of different metrics for evaluating
-models, see the
-[Classification section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/classification).
-
-
-### Predicting
-
-We've now trained a model and "proven" that it is good--but not
-perfect--at classifying Iris species. Now let's use the trained
-model to make some predictions on [**unlabeled
-examples**](https://developers.google.com/machine-learning/glossary/#unlabeled_example);
-that is, on examples that contain features but not a label.
-
-In real-life, the unlabeled examples could come from lots of different
-sources including apps, CSV files, and data feeds. For now, we're simply
-going to manually provide the following three unlabeled examples:
-
-```python
- predict_x = {
- 'SepalLength': [5.1, 5.9, 6.9],
- 'SepalWidth': [3.3, 3.0, 3.1],
- 'PetalLength': [1.7, 4.2, 5.4],
- 'PetalWidth': [0.5, 1.5, 2.1],
- }
-```
-
-Every Estimator provides a `predict` method, which `premade_estimator.py`
-calls as follows:
-
-```python
-predictions = classifier.predict(
- input_fn=lambda:eval_input_fn(predict_x,
- labels=None,
- batch_size=args.batch_size))
-```
-
-As with the `evaluate` method, our `predict` method also gathers examples
-from the `eval_input_fn` method.
-
-When doing predictions, we're *not* passing labels to `eval_input_fn`.
-Therefore, `eval_input_fn` does the following:
-
-1. Converts the features from the 3-element manual set we just created.
-2. Creates a batch of 3 examples from that manual set.
-3. Returns that batch of examples to `classifier.predict`.
-
-The `predict` method returns a python iterable, yielding a dictionary of
-prediction results for each example. This dictionary contains several keys.
-The `probabilities` key holds a list of three floating-point values,
-each representing the probability that the input example is a particular
-Iris species. For example, consider the following `probabilities` list:
-
-```none
-'probabilities': array([ 1.19127117e-08, 3.97069454e-02, 9.60292995e-01])
-```
-
-The preceding list indicates:
-
-* A negligible chance of the Iris being Setosa.
-* A 3.97% chance of the Iris being Versicolor.
-* A 96.0% chance of the Iris being Virginica.
-
-The `class_ids` key holds a one-element array that identifies the most
-probable species. For example:
-
-```none
-'class_ids': array([2])
-```
-
-The number `2` corresponds to Virginica. The following code iterates
-through the returned `predictions` to report on each prediction:
-
-``` python
-for pred_dict, expec in zip(predictions, expected):
- template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')
-
- class_id = pred_dict['class_ids'][0]
- probability = pred_dict['probabilities'][class_id]
- print(template.format(iris_data.SPECIES[class_id], 100 * probability, expec))
-```
-
-Running the program yields the following output:
-
-
-``` None
-...
-Prediction is "Setosa" (99.6%), expected "Setosa"
-
-Prediction is "Versicolor" (99.8%), expected "Versicolor"
-
-Prediction is "Virginica" (97.9%), expected "Virginica"
-```
-
-
-## Summary
-
-This document provides a short introduction to machine learning.
-
-Because `premade_estimators.py` relies on high-level APIs, much of the
-mathematical complexity in machine learning is hidden.
-If you intend to become more proficient in machine learning, we recommend
-ultimately learning more about [**gradient
-descent**](https://developers.google.com/machine-learning/glossary/#gradient_descent),
-batching, and neural networks.
-
-We recommend reading the @{$feature_columns$Feature Columns} document next,
-which explains how to represent different kinds of data in machine learning.
diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md
index 578080bb59..232d2f1547 100644
--- a/tensorflow/docs_src/get_started/index.md
+++ b/tensorflow/docs_src/get_started/index.md
@@ -15,26 +15,8 @@ The easiest way to get started with TensorFlow is by using Eager Execution.
* @{$get_started/eager}, is for anyone new to machine learning or TensorFlow.
TensorFlow provides many APIs. The remainder of this section focuses on the
-Estimator API which provide scalable, high-performance models.
-To get started with Estimators, begin by reading one of the following documents:
-
- * @{$get_started/get_started_for_beginners}, which is aimed at readers
- new to machine learning.
- * @{$get_started/premade_estimators}, which is aimed at readers who have
- experience in machine learning.
-
-Then, read the following documents, which demonstrate the key features
-in the high-level APIs:
-
- * @{$get_started/checkpoints}, which explains how to save training progress
- and resume where you left off.
- * @{$get_started/feature_columns}, which shows how an
- Estimator can handle a variety of input data types without changes to the
- model.
- * @{$get_started/datasets_quickstart}, which introduces TensorFlow's
- input pipelines.
- * @{$get_started/custom_estimators}, which demonstrates how
- to build and train models you design yourself.
+Estimator API which provide scalable, high-performance models. See the
+@{$estimators} guide.
For more advanced users:
diff --git a/tensorflow/docs_src/get_started/leftnav_files b/tensorflow/docs_src/get_started/leftnav_files
index 4c12f0d84b..e6cc8d5658 100644
--- a/tensorflow/docs_src/get_started/leftnav_files
+++ b/tensorflow/docs_src/get_started/leftnav_files
@@ -1,15 +1,4 @@
index.md
-### Beginners
eager.md
-get_started_for_beginners.md
-premade_estimators.md
-
-### Estimators
-get_started_for_beginners.md: For Beginners
-premade_estimators.md: Premade Estimators
->>>
-checkpoints.md
-feature_columns.md
datasets_quickstart.md
-custom_estimators.md
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index 016e7bf1b9..29a867a9e3 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -403,10 +403,8 @@ writing TensorFlow programs:
If the system outputs an error message instead of a greeting, see
[Common installation problems](#common_installation_problems).
-If you are new to machine learning, we recommend the following:
-
-* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course)
-* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners}
+If you are new to machine learning, we recommend the
+[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course).
If you are experienced with machine learning but new to TensorFlow, see
@{$get_started/eager}.
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index a139a49661..6c4f5b85ab 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -157,10 +157,8 @@ TensorFlow programs:
If the system outputs an error message instead of a greeting, see [Common
installation problems](#common_installation_problems).
-If you are new to machine learning, we recommend the following:
-
-* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course)
-* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners}
+If you are new to machine learning, we recommend the
+[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course).
If you are experienced with machine learning but new to TensorFlow, see
@{$get_started/eager}.
diff --git a/tensorflow/docs_src/get_started/checkpoints.md b/tensorflow/docs_src/programmers_guide/checkpoints.md
index 8dfd91e3c8..8dfd91e3c8 100644
--- a/tensorflow/docs_src/get_started/checkpoints.md
+++ b/tensorflow/docs_src/programmers_guide/checkpoints.md
diff --git a/tensorflow/docs_src/get_started/custom_estimators.md b/tensorflow/docs_src/programmers_guide/custom_estimators.md
index 275cda12bc..fb20b35c12 100644
--- a/tensorflow/docs_src/get_started/custom_estimators.md
+++ b/tensorflow/docs_src/programmers_guide/custom_estimators.md
@@ -5,7 +5,7 @@ This document introduces custom Estimators. In particular, this document
demonstrates how to create a custom @{tf.estimator.Estimator$Estimator} that
mimics the behavior of the pre-made Estimator
@{tf.estimator.DNNClassifier$`DNNClassifier`} in solving the Iris problem. See
-the @{$get_started/premade_estimators$Pre-Made Estimators chapter} for details
+the @{$premade_estimators$Pre-Made Estimators chapter} for details
on the Iris problem.
To download and access the example code invoke the following two commands:
@@ -84,7 +84,7 @@ and a logits output layer.
## Write an Input function
Our custom Estimator implementation uses the same input function as our
-@{$get_started/premade_estimators$pre-made Estimator implementation}, from
+@{$premade_estimators$pre-made Estimator implementation}, from
[`iris_data.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py).
Namely:
@@ -106,8 +106,8 @@ This input function builds an input pipeline that yields batches of
## Create feature columns
-As detailed in the @{$get_started/premade_estimators$Premade Estimators} and
-@{$get_started/feature_columns$Feature Columns} chapters, you must define
+As detailed in the @{$premade_estimators$Premade Estimators} and
+@{$feature_columns$Feature Columns} chapters, you must define
your model's feature columns to specify how the model should use each feature.
Whether working with pre-made Estimators or custom Estimators, you define
feature columns in the same fashion.
@@ -145,7 +145,7 @@ to the constructor are in turn passed on to the `model_fn`. In
[`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py)
the following lines create the estimator and set the params to configure the
model. This configuration step is similar to how we configured the @{tf.estimator.DNNClassifier} in
-@{$get_started/premade_estimators}.
+@{$premade_estimators}.
```python
classifier = tf.estimator.Estimator(
@@ -489,7 +489,7 @@ configure your Estimator without modifying the code in the `model_fn`.
The rest of the code to train, evaluate, and generate predictions using our
Estimator is the same as in the
-@{$get_started/premade_estimators$Premade Estimators} chapter. For
+@{$premade_estimators$Premade Estimators} chapter. For
example, the following line will train the model:
```python
diff --git a/tensorflow/docs_src/programmers_guide/estimators.md b/tensorflow/docs_src/programmers_guide/estimators.md
index de830112e0..b13b47184d 100644
--- a/tensorflow/docs_src/programmers_guide/estimators.md
+++ b/tensorflow/docs_src/programmers_guide/estimators.md
@@ -133,7 +133,7 @@ The heart of every Estimator--whether pre-made or custom--is its
evaluation, and prediction. When you are using a pre-made Estimator,
someone else has already implemented the model function. When relying
on a custom Estimator, you must write the model function yourself. A
-@{$get_started/custom_estimators$companion document}
+@{$custom_estimators$companion document}
explains how to write the model function.
diff --git a/tensorflow/docs_src/get_started/feature_columns.md b/tensorflow/docs_src/programmers_guide/feature_columns.md
index f152813442..90f5c53a17 100644
--- a/tensorflow/docs_src/get_started/feature_columns.md
+++ b/tensorflow/docs_src/programmers_guide/feature_columns.md
@@ -5,7 +5,7 @@ intermediaries between raw data and Estimators. Feature columns are very rich,
enabling you to transform a diverse range of raw data into formats that
Estimators can use, allowing easy experimentation.
-In @{$get_started/premade_estimators$Premade Estimators}, we used the premade
+In @{$premade_estimators$Premade Estimators}, we used the premade
Estimator, @{tf.estimator.DNNClassifier$`DNNClassifier`} to train a model to
predict different types of Iris flowers from four input features. That example
created only numerical feature columns (of type
diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md
index 648d001bd3..9ebfd39c56 100644
--- a/tensorflow/docs_src/programmers_guide/index.md
+++ b/tensorflow/docs_src/programmers_guide/index.md
@@ -11,6 +11,23 @@ works. The units are as follows:
* @{$programmers_guide/datasets}, which explains how to
set up data pipelines to read data sets into your TensorFlow program.
+## Estimators
+
+* @{$estimators} provides an introduction.
+* @{$premade_estimators}, introduces Estimators for machine learning.
+* @{$custom_estimators}, which demonstrates how to build and train models you
+ design yourself.
+* @{$feature_columns}, which shows how an Estimator can handle a variety of input
+ data types without changes to the model.
+* @{$checkpoints}, which explains how to save training progress and resume where
+ you left off.
+
+## Accelerators
+
+ * @{$using_gpu} explains how TensorFlow assigns operations to
+ devices and how you can change the arrangement manually.
+ * @{$using_tpu} explains how to modify `Estimator` programs to run on a TPU.
+
## Low Level APIs
* @{$programmers_guide/low_level_intro}, which introduces the
@@ -32,13 +49,6 @@ works. The units are as follows:
* @{$programmers_guide/saved_model}, which
explains how to save and restore variables and models.
-## Accelerators
-
- * @{$using_gpu} explains how TensorFlow assigns operations to
- devices and how you can change the arrangement manually.
- * @{$using_tpu} explains how to modify `Estimator` programs to run on a TPU.
-
-
## ML Concepts
* @{$programmers_guide/embedding}, which introduces the concept
diff --git a/tensorflow/docs_src/programmers_guide/leftnav_files b/tensorflow/docs_src/programmers_guide/leftnav_files
index 7ac63bf2e0..331317446a 100644
--- a/tensorflow/docs_src/programmers_guide/leftnav_files
+++ b/tensorflow/docs_src/programmers_guide/leftnav_files
@@ -3,7 +3,17 @@ index.md
### High Level APIs
eager.md
datasets.md
-estimators.md
+
+### Estimators
+estimators.md: Introduction to Estimators
+premade_estimators.md
+custom_estimators.md
+feature_columns.md
+checkpoints.md
+
+### Accelerators
+using_gpu.md
+using_tpu.md
### Low Level APIs
low_level_intro.md
@@ -12,10 +22,6 @@ variables.md
graphs.md
saved_model.md
-### Accelerators
-using_gpu.md
-using_tpu.md
-
### ML Concepts
embedding.md
diff --git a/tensorflow/docs_src/programmers_guide/low_level_intro.md b/tensorflow/docs_src/programmers_guide/low_level_intro.md
index 05709ad10a..478e2bb70b 100644
--- a/tensorflow/docs_src/programmers_guide/low_level_intro.md
+++ b/tensorflow/docs_src/programmers_guide/low_level_intro.md
@@ -9,7 +9,7 @@ This guide gets you started programming in the low-level TensorFlow APIs
* Use high level components ([datasets](#datasets), [layers](#layers), and
[feature_columns](#feature_columns)) in this low level environment.
* Build your own training loop, instead of using the one
- @{$get_started/premade_estimators$provided by Estimators}.
+ @{$premade_estimators$provided by Estimators}.
We recommend using the higher level APIs to build models when possible.
Knowing TensorFlow Core is valuable for the following reasons:
@@ -398,7 +398,7 @@ and layer reuse impossible.
The easiest way to experiment with feature columns is using the
@{tf.feature_column.input_layer} function. This function only accepts
-@{$get_started/feature_columns$dense columns} as inputs, so to view the result
+@{$feature_columns$dense columns} as inputs, so to view the result
of a categorical column you must wrap it in an
@{tf.feature_column.indicator_column}. For example:
@@ -589,7 +589,7 @@ print(sess.run(y_pred))
To learn more about building models with TensorFlow consider the following:
-* @{$get_started/custom_estimators$Custom Estimators}, to learn how to build
+* @{$custom_estimators$Custom Estimators}, to learn how to build
customized models with TensorFlow. Your knowledge of TensorFlow Core will
help you understand and debug your own models.
diff --git a/tensorflow/docs_src/get_started/premade_estimators.md b/tensorflow/docs_src/programmers_guide/premade_estimators.md
index 15853bc0ab..f6dd75eaca 100644
--- a/tensorflow/docs_src/get_started/premade_estimators.md
+++ b/tensorflow/docs_src/programmers_guide/premade_estimators.md
@@ -177,13 +177,11 @@ other features so you can concentrate on your model. For more details see
An Estimator is any class derived from @{tf.estimator.Estimator}. TensorFlow
provides a collection of
-[pre-made Estimators](https://developers.google.com/machine-learning/glossary/#premade_Estimator)
+@{tf.estimator$pre-made Estimators}
(for example, `LinearRegressor`) to implement common ML algorithms. Beyond
those, you may write your own
-[custom Estimators](https://developers.google.com/machine-learning/glossary/#custom_Estimator).
-We recommend using pre-made Estimators when just getting started with
-TensorFlow. After gaining expertise with the pre-made Estimators, we recommend
-optimizing your model by creating your own custom Estimators.
+@{$custom_estimators$custom Estimators}.
+We recommend using pre-made Estimators when just getting started.
To write a TensorFlow program based on pre-made Estimators, you must perform the
following tasks:
@@ -289,7 +287,7 @@ for key in train_x.keys():
```
Feature columns can be far more sophisticated than those we're showing here. We
-detail feature columns @{$get_started/feature_columns$later on} in our Getting
+detail feature columns @{$feature_columns$later on} in our Getting
Started guide.
Now that we have the description of how we want the model to represent the raw
@@ -425,11 +423,10 @@ Pre-made Estimators are an effective way to quickly create standard models.
Now that you've gotten started writing TensorFlow programs, consider the
following material:
-* @{$get_started/checkpoints$Checkpoints} to learn how to save and restore
- models.
+* @{$checkpoints$Checkpoints} to learn how to save and restore models.
* @{$get_started/datasets_quickstart$Datasets} to learn more about importing
data into your
model.
-* @{$get_started/custom_estimators$Creating Custom Estimators} to learn how to
+* @{$custom_estimators$Creating Custom Estimators} to learn how to
write your own Estimator, customized for a particular problem.
diff --git a/tensorflow/docs_src/programmers_guide/using_tpu.md b/tensorflow/docs_src/programmers_guide/using_tpu.md
index 5e3e49d434..44aabf0557 100644
--- a/tensorflow/docs_src/programmers_guide/using_tpu.md
+++ b/tensorflow/docs_src/programmers_guide/using_tpu.md
@@ -22,8 +22,8 @@ Standard `Estimators` can drive models on CPU and GPUs. You must use
@{tf.contrib.tpu.TPUEstimator} to drive a model on TPUs.
Refer to TensorFlow's Getting Started section for an introduction to the basics
-of using a @{$get_started/premade_estimators$pre-made `Estimator`}, and
-@{$get_started/custom_estimators$custom `Estimator`s}.
+of using a @{$premade_estimators$pre-made `Estimator`}, and
+@{$custom_estimators$custom `Estimator`s}.
The `TPUEstimator` class differs somewhat from the `Estimator` class.
diff --git a/tensorflow/docs_src/tutorials/kernel_methods.md b/tensorflow/docs_src/tutorials/kernel_methods.md
index 73e5c51057..205e2a2d2c 100644
--- a/tensorflow/docs_src/tutorials/kernel_methods.md
+++ b/tensorflow/docs_src/tutorials/kernel_methods.md
@@ -53,7 +53,7 @@ In order to feed data to a `tf.contrib.learn Estimator`, it is helpful to conver
it to Tensors. For this, we will use an `input function` which adds Ops to the
TensorFlow graph that, when executed, create mini-batches of Tensors to be used
downstream. For more background on input functions, check
-@{$get_started/premade_estimators#create_input_functions$this section on input functions}.
+@{$premade_estimators#create_input_functions$this section on input functions}.
In this example, we will use the `tf.train.shuffle_batch` Op which, besides
converting numpy arrays to Tensors, allows us to specify the batch_size and
whether to randomize the input every time the input_fn Ops are executed
diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/layers.md
index 496b1e4da9..0f17899dae 100644
--- a/tensorflow/docs_src/tutorials/layers.md
+++ b/tensorflow/docs_src/tutorials/layers.md
@@ -190,7 +190,7 @@ def cnn_model_fn(features, labels, mode):
The following sections (with headings corresponding to each code block above)
dive deeper into the `tf.layers` code used to create each layer, as well as how
to calculate loss, configure the training op, and generate predictions. If
-you're already experienced with CNNs and @{$get_started/custom_estimators$TensorFlow `Estimator`s},
+you're already experienced with CNNs and @{$custom_estimators$TensorFlow `Estimator`s},
and find the above code intuitive, you may want to skim these sections or just
skip ahead to ["Training and Evaluating the CNN MNIST Classifier"](#train_eval_mnist).
@@ -534,8 +534,8 @@ if mode == tf.estimator.ModeKeys.TRAIN:
```
> Note: For a more in-depth look at configuring training ops for Estimator model
-> functions, see @{$get_started/custom_estimators#defining-the-training-op-for-the-model$"Defining the training op for the model"}
-> in the @{$get_started/custom_estimators$"Creating Estimations in tf.estimator"} tutorial.
+> functions, see @{$custom_estimators#defining-the-training-op-for-the-model$"Defining the training op for the model"}
+> in the @{$custom_estimators$"Creating Estimations in tf.estimator"} tutorial.
### Add evaluation metrics
@@ -600,7 +600,7 @@ be saved (here, we specify the temp directory `/tmp/mnist_convnet_model`, but
feel free to change to another directory of your choice).
> Note: For an in-depth walkthrough of the TensorFlow `Estimator` API, see the
-> tutorial @{$get_started/custom_estimators$"Creating Estimators in tf.estimator."}
+> tutorial @{$custom_estimators$"Creating Estimators in tf.estimator."}
### Set Up a Logging Hook {#set_up_a_logging_hook}
@@ -719,7 +719,7 @@ Here, we've achieved an accuracy of 97.3% on our test data set.
To learn more about TensorFlow Estimators and CNNs in TensorFlow, see the
following resources:
-* @{$get_started/custom_estimators$Creating Estimators in tf.estimator}
+* @{$custom_estimators$Creating Estimators in tf.estimator}
provides an introduction to the TensorFlow Estimator API. It walks through
configuring an Estimator, writing a model function, calculating loss, and
defining a training op.
diff --git a/tensorflow/docs_src/tutorials/linear.md b/tensorflow/docs_src/tutorials/linear.md
index 265ded877d..3f247ade26 100644
--- a/tensorflow/docs_src/tutorials/linear.md
+++ b/tensorflow/docs_src/tutorials/linear.md
@@ -17,7 +17,7 @@ tutorial walks through the code in greater detail.
To understand this overview it will help to have some familiarity
with basic machine learning concepts, and also with
-@{$get_started/premade_estimators$Estimators}.
+@{$premade_estimators$Estimators}.
[TOC]
diff --git a/tensorflow/docs_src/tutorials/recurrent_quickdraw.md b/tensorflow/docs_src/tutorials/recurrent_quickdraw.md
index 5d83fbe2a3..1afd861738 100644
--- a/tensorflow/docs_src/tutorials/recurrent_quickdraw.md
+++ b/tensorflow/docs_src/tutorials/recurrent_quickdraw.md
@@ -220,7 +220,7 @@ length 2.
### Defining the model
To define the model we create a new `Estimator`. If you want to read more about
-estimators, we recommend @{$get_started/custom_estimators$this tutorial}.
+estimators, we recommend @{$custom_estimators$this tutorial}.
To build the model, we:
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index e08f38969c..0dd3726948 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2674,29 +2674,50 @@ func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_uppe
return op.Output(0)
}
-// Clips tensor values to a specified min and max.
+// Returns the batched diagonal part of a batched tensor.
//
-// Given a tensor `t`, this operation returns a tensor of the same type and
-// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
-// Any values less than `clip_value_min` are set to `clip_value_min`. Any values
-// greater than `clip_value_max` are set to `clip_value_max`.
+// This operation returns a tensor with the `diagonal` part
+// of the batched `input`. The `diagonal` part is computed as follows:
+//
+// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
+// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where:
+//
+// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`.
+//
+// The input must be at least a matrix.
+//
+// For example:
+//
+// ```
+// # 'input' is [[[1, 0, 0, 0]
+// [0, 2, 0, 0]
+// [0, 0, 3, 0]
+// [0, 0, 0, 4]],
+// [[5, 0, 0, 0]
+// [0, 6, 0, 0]
+// [0, 0, 7, 0]
+// [0, 0, 0, 8]]]
+//
+// and input.shape = (2, 4, 4)
+//
+// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]]
+//
+// which has shape (2, 4)
+// ```
//
// Arguments:
-// t: A `Tensor`.
-// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
-// as `t`. The minimum value to clip by.
-// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
-// as `t`. The maximum value to clip by.
+// input: Rank `k` tensor where `k >= 2`.
//
-// Returns A clipped `Tensor` with the same shape as input 't'.
-func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) {
+// Returns The extracted diagonal(s) having shape
+// `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`.
+func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "ClipByValue",
+ Type: "MatrixDiagPart",
Input: []tf.Input{
- t, clip_value_min, clip_value_max,
+ input,
},
}
op := scope.AddOperation(opspec)
@@ -4563,6 +4584,68 @@ func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Returns a batched matrix tensor with new batched diagonal values.
+//
+// Given `input` and `diagonal`, this operation returns a tensor with the
+// same shape and values as `input`, except for the main diagonal of the
+// innermost matrices. These will be overwritten by the values in `diagonal`.
+//
+// The output is computed as follows:
+//
+// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has
+// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a
+// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where:
+//
+// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`.
+// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`.
+//
+// Arguments:
+// input: Rank `k+1`, where `k >= 1`.
+// diagonal: Rank `k`, where `k >= 1`.
+//
+// Returns Rank `k+1`, with `output.shape = input.shape`.
+func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "MatrixSetDiag",
+ Input: []tf.Input{
+ input, diagonal,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns the element-wise max of two SparseTensors.
+//
+// Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
+//
+// Arguments:
+// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
+// SparseTensor, in the canonical lexicographic ordering.
+// a_values: 1-D. `N` non-empty values corresponding to `a_indices`.
+// a_shape: 1-D. Shape of the input SparseTensor.
+// b_indices: counterpart to `a_indices` for the other operand.
+// b_values: counterpart to `a_values` for the other operand; must be of the same dtype.
+// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal.
+//
+// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor.
+func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSparseMaximum",
+ Input: []tf.Input{
+ a_indices, a_values, a_shape, b_indices, b_values, b_shape,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// OrderedMapClearAttr is an optional argument to OrderedMapClear.
type OrderedMapClearAttr func(optionalAttr)
@@ -7310,6 +7393,47 @@ func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...
return op.Output(0)
}
+// RandomPoissonAttr is an optional argument to RandomPoisson.
+type RandomPoissonAttr func(optionalAttr)
+
+// RandomPoissonSeed sets the optional seed attribute to value.
+// If not specified, defaults to 0
+func RandomPoissonSeed(value int64) RandomPoissonAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomPoissonSeed2 sets the optional seed2 attribute to value.
+// If not specified, defaults to 0
+func RandomPoissonSeed2(value int64) RandomPoissonAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Use RandomPoissonV2 instead.
+//
+// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2
+func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomPoisson",
+ Input: []tf.Input{
+ shape, rate,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
@@ -7768,47 +7892,6 @@ func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf
return output_indices, output_values, output_shape
}
-// RandomPoissonAttr is an optional argument to RandomPoisson.
-type RandomPoissonAttr func(optionalAttr)
-
-// RandomPoissonSeed sets the optional seed attribute to value.
-// If not specified, defaults to 0
-func RandomPoissonSeed(value int64) RandomPoissonAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomPoissonSeed2 sets the optional seed2 attribute to value.
-// If not specified, defaults to 0
-func RandomPoissonSeed2(value int64) RandomPoissonAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Use RandomPoissonV2 instead.
-//
-// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2
-func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomPoisson",
- Input: []tf.Input{
- shape, rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2.
type ResourceSparseApplyFtrlV2Attr func(optionalAttr)
@@ -10094,6 +10177,43 @@ func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, o
return op.Output(0)
}
+// Says whether the targets are in the top `K` predictions.
+//
+// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
+// prediction for the target class is among the top `k` predictions among
+// all predictions for example `i`. Note that the behavior of `InTopK` differs
+// from the `TopK` op in its handling of ties; if multiple classes have the
+// same prediction value and straddle the top-`k` boundary, all of those
+// classes are considered to be in the top `k`.
+//
+// More formally, let
+//
+// \\(predictions_i\\) be the predictions for all classes for example `i`,
+// \\(targets_i\\) be the target class for example `i`,
+// \\(out_i\\) be the output for example `i`,
+//
+// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
+//
+// Arguments:
+// predictions: A `batch_size` x `classes` tensor.
+// targets: A `batch_size` vector of class ids.
+// k: Number of top elements to look at for computing precision.
+//
+// Returns Computed precision at `k` as a `bool Tensor`.
+func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "InTopKV2",
+ Input: []tf.Input{
+ predictions, targets, k,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg.
type DecodeAndCropJpegAttr func(optionalAttr)
@@ -10949,101 +11069,6 @@ func Fact(scope *Scope) (fact tf.Output) {
return op.Output(0)
}
-// AngleAttr is an optional argument to Angle.
-type AngleAttr func(optionalAttr)
-
-// AngleTout sets the optional Tout attribute to value.
-// If not specified, defaults to DT_FLOAT
-func AngleTout(value tf.DataType) AngleAttr {
- return func(m optionalAttr) {
- m["Tout"] = value
- }
-}
-
-// Returns the argument of a complex number.
-//
-// Given a tensor `input` of complex numbers, this operation returns a tensor of
-// type `float` that is the argument of each element in `input`. All elements in
-// `input` must be complex numbers of the form \\(a + bj\\), where *a*
-// is the real part and *b* is the imaginary part.
-//
-// The argument returned by this operation is of the form \\(atan2(b, a)\\).
-//
-// For example:
-//
-// ```
-// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
-// tf.angle(input) ==> [2.0132, 1.056]
-// ```
-//
-// @compatibility(numpy)
-// Equivalent to np.angle.
-// @end_compatibility
-func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Angle",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// VarHandleOpAttr is an optional argument to VarHandleOp.
-type VarHandleOpAttr func(optionalAttr)
-
-// VarHandleOpContainer sets the optional container attribute to value.
-//
-// value: the container this variable is placed in.
-// If not specified, defaults to ""
-func VarHandleOpContainer(value string) VarHandleOpAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// VarHandleOpSharedName sets the optional shared_name attribute to value.
-//
-// value: the name by which this variable is referred to.
-// If not specified, defaults to ""
-func VarHandleOpSharedName(value string) VarHandleOpAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// Creates a handle to a Variable resource.
-//
-// Arguments:
-// dtype: the type of this variable. Must agree with the dtypes
-// of all ops using this variable.
-// shape: The (possibly partially specified) shape of this variable.
-func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype, "shape": shape}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "VarHandleOp",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Elementwise computes the bitwise XOR of `x` and `y`.
//
// The result will have those bits set, that are different in `x` and `y`. The
@@ -18002,43 +18027,6 @@ func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padd
return op.Output(0)
}
-// Says whether the targets are in the top `K` predictions.
-//
-// This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
-// prediction for the target class is among the top `k` predictions among
-// all predictions for example `i`. Note that the behavior of `InTopK` differs
-// from the `TopK` op in its handling of ties; if multiple classes have the
-// same prediction value and straddle the top-`k` boundary, all of those
-// classes are considered to be in the top `k`.
-//
-// More formally, let
-//
-// \\(predictions_i\\) be the predictions for all classes for example `i`,
-// \\(targets_i\\) be the target class for example `i`,
-// \\(out_i\\) be the output for example `i`,
-//
-// $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
-//
-// Arguments:
-// predictions: A `batch_size` x `classes` tensor.
-// targets: A `batch_size` vector of class ids.
-// k: Number of top elements to look at for computing precision.
-//
-// Returns Computed precision at `k` as a `bool Tensor`.
-func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Output) (precision tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "InTopKV2",
- Input: []tf.Input{
- predictions, targets, k,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Assigns a new value to a variable.
//
// Any ReadVariableOp with a control dependency on this op is guaranteed to return
@@ -19594,6 +19582,130 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or
return op.Output(0)
}
+// VarHandleOpAttr is an optional argument to VarHandleOp.
+type VarHandleOpAttr func(optionalAttr)
+
+// VarHandleOpContainer sets the optional container attribute to value.
+//
+// value: the container this variable is placed in.
+// If not specified, defaults to ""
+func VarHandleOpContainer(value string) VarHandleOpAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// VarHandleOpSharedName sets the optional shared_name attribute to value.
+//
+// value: the name by which this variable is referred to.
+// If not specified, defaults to ""
+func VarHandleOpSharedName(value string) VarHandleOpAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Creates a handle to a Variable resource.
+//
+// Arguments:
+// dtype: the type of this variable. Must agree with the dtypes
+// of all ops using this variable.
+// shape: The (possibly partially specified) shape of this variable.
+func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...VarHandleOpAttr) (resource tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype, "shape": shape}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "VarHandleOp",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// AngleAttr is an optional argument to Angle.
+type AngleAttr func(optionalAttr)
+
+// AngleTout sets the optional Tout attribute to value.
+// If not specified, defaults to DT_FLOAT
+func AngleTout(value tf.DataType) AngleAttr {
+ return func(m optionalAttr) {
+ m["Tout"] = value
+ }
+}
+
+// Returns the argument of a complex number.
+//
+// Given a tensor `input` of complex numbers, this operation returns a tensor of
+// type `float` that is the argument of each element in `input`. All elements in
+// `input` must be complex numbers of the form \\(a + bj\\), where *a*
+// is the real part and *b* is the imaginary part.
+//
+// The argument returned by this operation is of the form \\(atan2(b, a)\\).
+//
+// For example:
+//
+// ```
+// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+// tf.angle(input) ==> [2.0132, 1.056]
+// ```
+//
+// @compatibility(numpy)
+// Equivalent to np.angle.
+// @end_compatibility
+func Angle(scope *Scope, input tf.Output, optional ...AngleAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Angle",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Clips tensor values to a specified min and max.
+//
+// Given a tensor `t`, this operation returns a tensor of the same type and
+// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
+// Any values less than `clip_value_min` are set to `clip_value_min`. Any values
+// greater than `clip_value_max` are set to `clip_value_max`.
+//
+// Arguments:
+// t: A `Tensor`.
+// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
+// as `t`. The minimum value to clip by.
+// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
+// as `t`. The maximum value to clip by.
+//
+// Returns A clipped `Tensor` with the same shape as input 't'.
+func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ClipByValue",
+ Input: []tf.Input{
+ t, clip_value_min, clip_value_max,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Counts the number of occurrences of each value in an integer array.
//
// Outputs a vector with length `size` and the same dtype as `weights`. If
@@ -26649,56 +26761,6 @@ func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) {
return op.Output(0)
}
-// Returns the batched diagonal part of a batched tensor.
-//
-// This operation returns a tensor with the `diagonal` part
-// of the batched `input`. The `diagonal` part is computed as follows:
-//
-// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
-// tensor of rank `k - 1` with dimensions `[I, J, K, ..., min(M, N)]` where:
-//
-// `diagonal[i, j, k, ..., n] = input[i, j, k, ..., n, n]`.
-//
-// The input must be at least a matrix.
-//
-// For example:
-//
-// ```
-// # 'input' is [[[1, 0, 0, 0]
-// [0, 2, 0, 0]
-// [0, 0, 3, 0]
-// [0, 0, 0, 4]],
-// [[5, 0, 0, 0]
-// [0, 6, 0, 0]
-// [0, 0, 7, 0]
-// [0, 0, 0, 8]]]
-//
-// and input.shape = (2, 4, 4)
-//
-// tf.matrix_diag_part(input) ==> [[1, 2, 3, 4], [5, 6, 7, 8]]
-//
-// which has shape (2, 4)
-// ```
-//
-// Arguments:
-// input: Rank `k` tensor where `k >= 2`.
-//
-// Returns The extracted diagonal(s) having shape
-// `diagonal.shape = input.shape[:-2] + [min(input.shape[-2:])]`.
-func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "MatrixDiagPart",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the absolute value of a tensor.
//
// Given a tensor `x`, this operation returns a tensor containing the absolute
@@ -30648,65 +30710,3 @@ func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) {
op := scope.AddOperation(opspec)
return op.Output(0)
}
-
-// Returns the element-wise max of two SparseTensors.
-//
-// Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
-//
-// Arguments:
-// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
-// SparseTensor, in the canonical lexicographic ordering.
-// a_values: 1-D. `N` non-empty values corresponding to `a_indices`.
-// a_shape: 1-D. Shape of the input SparseTensor.
-// b_indices: counterpart to `a_indices` for the other operand.
-// b_values: counterpart to `a_values` for the other operand; must be of the same dtype.
-// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal.
-//
-// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor.
-func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseSparseMaximum",
- Input: []tf.Input{
- a_indices, a_values, a_shape, b_indices, b_values, b_shape,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
-// Returns a batched matrix tensor with new batched diagonal values.
-//
-// Given `input` and `diagonal`, this operation returns a tensor with the
-// same shape and values as `input`, except for the main diagonal of the
-// innermost matrices. These will be overwritten by the values in `diagonal`.
-//
-// The output is computed as follows:
-//
-// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has
-// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a
-// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where:
-//
-// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`.
-// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`.
-//
-// Arguments:
-// input: Rank `k+1`, where `k >= 1`.
-// diagonal: Rank `k`, where `k >= 1`.
-//
-// Returns Rank `k+1`, with `output.shape = input.shape`.
-func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "MatrixSetDiag",
- Input: []tf.Input{
- input, diagonal,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 50d5dcef54..19d2133a55 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -60,9 +60,7 @@ java_library(
filegroup(
name = "java_op_sources",
- srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [
- ":java_op_gen_sources",
- ],
+ srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [":java_op_gen_sources"],
visibility = [
"//tensorflow/java:__pkg__",
],
@@ -87,6 +85,9 @@ tf_cc_binary(
linkstatic = 1,
deps = [
":java_op_gen_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
"//tensorflow/core:ops",
],
)
@@ -111,6 +112,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_googlesource_code_re2//:re2",
],
)
@@ -303,6 +306,7 @@ tf_cc_test(
],
deps = [
":java_op_gen_lib",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index 62575f6683..f5f54bf4d3 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -26,12 +26,12 @@ namespace java {
// An enumeration of different modifiers commonly used in Java
enum Modifier {
- PACKAGE = 0,
- PUBLIC = (1 << 0),
+ PACKAGE = 0,
+ PUBLIC = (1 << 0),
PROTECTED = (1 << 1),
- PRIVATE = (1 << 2),
- STATIC = (1 << 3),
- FINAL = (1 << 4),
+ PRIVATE = (1 << 2),
+ STATIC = (1 << 3),
+ FINAL = (1 << 4),
};
class Annotation;
@@ -75,12 +75,8 @@ class Type {
// Reflection API does
return Type(Type::PRIMITIVE, "void");
}
- static Type Generic(const string& name) {
- return Type(Type::GENERIC, name);
- }
- static Type Wildcard() {
- return Type(Type::GENERIC, "");
- }
+ static Type Generic(const string& name) { return Type(Type::GENERIC, name); }
+ static Type Wildcard() { return Type(Type::GENERIC, ""); }
static Type Class(const string& name, const string& package = "") {
return Type(Type::CLASS, name, package);
}
@@ -226,9 +222,7 @@ class Method {
// A definition of a documentation bloc for a Java element (JavaDoc)
class Javadoc {
public:
- static Javadoc Create(const string& brief = "") {
- return Javadoc(brief);
- }
+ static Javadoc Create(const string& brief = "") { return Javadoc(brief); }
const string& brief() const { return brief_; }
const string& details() const { return details_; }
Javadoc& details(const string& details) {
diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc
index 6c35cd9595..0d9e0883af 100644
--- a/tensorflow/java/src/gen/cc/op_gen_main.cc
+++ b/tensorflow/java/src/gen/cc/op_gen_main.cc
@@ -41,7 +41,7 @@ const char kUsageHeader[] =
"using an appropriate annotation processor.\n\n"
"The '--base_package' overrides the default parent package under which "
"the generated subpackage and classes are to be located.\n\n"
- "Finally, the `--api_dirs` argument takes a list of comma-seperated "
+ "Finally, the `--api_dirs` argument takes a list of comma-separated "
"directories of API definitions can be provided to override default\n"
"values found in the ops definitions. Directories are ordered by priority "
"(the last having precedence over the first).\n\n";
@@ -55,10 +55,12 @@ int main(int argc, char* argv[]) {
tensorflow::string api_dirs_str;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("output_dir", &output_dir,
- "Root directory into which output files are generated"),
- tensorflow::Flag("base_package", &base_package,
+ "Root directory into which output files are generated"),
+ tensorflow::Flag(
+ "base_package", &base_package,
"Package parent to the generated subpackage and classes"),
- tensorflow::Flag("api_dirs", &api_dirs_str,
+ tensorflow::Flag(
+ "api_dirs", &api_dirs_str,
"List of directories that contains the ops api definitions")};
tensorflow::string usage = tensorflow::java::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index 284f675c94..940390bfcf 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -17,39 +17,43 @@ limitations under the License.
#include <map>
#include <vector>
#include <list>
+#include <map>
#include <memory>
#include <set>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/java/src/gen/cc/java_defs.h"
-#include "tensorflow/java/src/gen/cc/source_writer.h"
#include "tensorflow/java/src/gen/cc/op_generator.h"
#include "tensorflow/java/src/gen/cc/op_specs.h"
+#include "tensorflow/java/src/gen/cc/source_writer.h"
namespace tensorflow {
namespace java {
namespace {
const char* kLicense =
- "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
- "\n"
- "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
- "you may not use this file except in compliance with the License.\n"
- "You may obtain a copy of the License at\n"
- "\n"
- " http://www.apache.org/licenses/LICENSE-2.0\n"
- "\n"
- "Unless required by applicable law or agreed to in writing, software\n"
- "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
- "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
- "See the License for the specific language governing permissions and\n"
- "limitations under the License.\n"
- "=======================================================================*/\n";
+ "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
+ "\n"
+ "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
+ "you may not use this file except in compliance with the License.\n"
+ "You may obtain a copy of the License at\n"
+ "\n"
+ " http://www.apache.org/licenses/LICENSE-2.0\n"
+ "\n"
+ "Unless required by applicable law or agreed to in writing, software\n"
+ "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+ "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+ "See the License for the specific language governing permissions and\n"
+ "limitations under the License.\n"
+ "=======================================================================*/"
+ "\n";
// There is three different modes to render an op class, depending on the
// number and type of outputs it has:
@@ -64,20 +68,16 @@ const char* kLicense =
// allowing an instance to be passed directly as a list input to
// another operation
//
-enum RenderMode {
- DEFAULT,
- OPERAND,
- LIST_OPERAND
-};
+enum RenderMode { DEFAULT, OPERAND, LIST_OPERAND };
void AddArgument(const Variable& var, const string& description,
- Method* method_out, Javadoc* javadoc_out) {
+ Method* method_out, Javadoc* javadoc_out) {
method_out->add_argument(var);
javadoc_out->add_param_tag(var.name(), description);
}
void CollectOpDependencies(const OpSpec& op, RenderMode mode,
- std::list<Type>* out) {
+ std::list<Type>* out) {
out->push_back(Type::Class("Operation", "org.tensorflow"));
out->push_back(Type::Class("OperationBuilder", "org.tensorflow"));
out->push_back(Type::Class("Scope", "org.tensorflow.op"));
@@ -110,7 +110,7 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
}
void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
- SourceWriter* writer) {
+ SourceWriter* writer) {
string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
if (attr.iterable()) {
string array_name = attr.var().name() + "Array";
@@ -143,11 +143,11 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
}
void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
- SourceWriter* writer) {
+ SourceWriter* writer) {
Method factory = Method::Create("create", op_class);
- Javadoc factory_doc = Javadoc::Create(
- "Factory method to create a class to wrap a new " + op_class.name()
- + " operation to the graph.");
+ Javadoc factory_doc =
+ Javadoc::Create("Factory method to create a class to wrap a new " +
+ op_class.name() + " operation to the graph.");
Variable scope =
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
AddArgument(scope, "current graph scope", &factory, &factory_doc);
@@ -159,23 +159,23 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
}
if (!op.optional_attributes().empty()) {
AddArgument(Variable::Varargs("options", Type::Class("Options")),
- "carries optional attributes values", &factory, &factory_doc);
+ "carries optional attributes values", &factory, &factory_doc);
}
factory_doc.add_tag("return", "a new instance of " + op_class.name());
- writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc);
- writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\""
- + op.graph_op_name() + "\", scope.makeOpName(\""
- + op_class.name() + "\"));");
+ writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
+ writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" +
+ op.graph_op_name() + "\", scope.makeOpName(\"" +
+ op_class.name() + "\"));");
writer->EndLine();
for (const ArgumentSpec& input : op.inputs()) {
if (input.iterable()) {
- writer->Append("opBuilder.addInputList(Operands.asOutputs("
- + input.var().name() + "));");
+ writer->Append("opBuilder.addInputList(Operands.asOutputs(" +
+ input.var().name() + "));");
writer->EndLine();
} else {
- writer->Append("opBuilder.addInput(" + input.var().name()
- + ".asOutput());");
+ writer->Append("opBuilder.addInput(" + input.var().name() +
+ ".asOutput());");
writer->EndLine();
}
}
@@ -200,7 +200,7 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
}
void RenderConstructor(const OpSpec& op, const Type& op_class,
- SourceWriter* writer) {
+ SourceWriter* writer) {
Variable operation =
Variable::Create("operation", Type::Class("Operation", "org.tensorflow"));
Method constructor = Method::ConstructorFor(op_class).add_argument(operation);
@@ -214,15 +214,14 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
writer->BeginMethod(constructor, PRIVATE)
.Append("super(operation);")
.EndLine();
- if (op.outputs().size() > 0) {
- writer->Append("int outputIdx = 0;")
- .EndLine();
+ if (!op.outputs().empty()) {
+ writer->Append("int outputIdx = 0;").EndLine();
for (const ArgumentSpec& output : op.outputs()) {
if (output.iterable()) {
string var_length = output.var().name() + "Length";
writer->Append("int " + var_length)
- .Append(" = operation.outputListLength(\"" + output.op_def_name()
- + "\");")
+ .Append(" = operation.outputListLength(\"" + output.op_def_name() +
+ "\");")
.EndLine()
.Append(output.var().name() + " = Arrays.asList(");
if (!output.type().wildcard()) {
@@ -235,8 +234,8 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
.Append("outputIdx += " + var_length + ";")
.EndLine();
} else {
- writer->Append(output.var().name()
- + " = operation.output(outputIdx++);")
+ writer
+ ->Append(output.var().name() + " = operation.output(outputIdx++);")
.EndLine();
}
}
@@ -246,13 +245,12 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
for (const AttributeSpec& attr : op.optional_attributes()) {
- Method setter =
- Method::Create(attr.var().name(), Type::Class("Options"));
+ Method setter = Method::Create(attr.var().name(), Type::Class("Options"));
Javadoc setter_doc = Javadoc::Create();
AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
- writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc)
- .Append("return new Options()." + attr.var().name() + "("
- + attr.var().name() + ");")
+ writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc)
+ .Append("return new Options()." + attr.var().name() + "(" +
+ attr.var().name() + ");")
.EndLine()
.EndMethod();
}
@@ -267,15 +265,16 @@ void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
}
void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
- SourceWriter* writer) {
+ SourceWriter* writer) {
ArgumentSpec output = op.outputs().front();
if (mode == OPERAND) {
bool cast2obj = output.type().wildcard();
- Type return_type = Type::Class("Output", "org.tensorflow")
- .add_parameter(cast2obj ? Type::Class("Object") : output.type());
+ Type return_type =
+ Type::Class("Output", "org.tensorflow")
+ .add_parameter(cast2obj ? Type::Class("Object") : output.type());
Method as_output = Method::Create("asOutput", return_type)
- .add_annotation(Annotation::Create("Override"));
+ .add_annotation(Annotation::Create("Override"));
if (cast2obj) {
as_output.add_annotation(
Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
@@ -286,9 +285,7 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
} else {
writer->Append("return ");
}
- writer->Append(output.var().name() + ";")
- .EndLine()
- .EndMethod();
+ writer->Append(output.var().name() + ";").EndLine().EndMethod();
} else if (mode == LIST_OPERAND) {
Type operand = Type::Interface("Operand", "org.tensorflow");
@@ -297,12 +294,13 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
} else {
operand.add_parameter(output.type());
}
- Type return_type = Type::Interface("Iterator", "java.util")
- .add_parameter(operand);
- Method iterator = Method::Create("iterator", return_type)
- .add_annotation(Annotation::Create("Override"))
- .add_annotation(Annotation::Create("SuppressWarnings")
- .attributes("{\"rawtypes\", \"unchecked\"}"));
+ Type return_type =
+ Type::Interface("Iterator", "java.util").add_parameter(operand);
+ Method iterator =
+ Method::Create("iterator", return_type)
+ .add_annotation(Annotation::Create("Override"))
+ .add_annotation(Annotation::Create("SuppressWarnings")
+ .attributes("{\"rawtypes\", \"unchecked\"}"));
// cast the output list using a raw List
writer->BeginMethod(iterator, PUBLIC)
.Append("return (" + return_type.name() + ") ")
@@ -313,10 +311,10 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
}
void RenderOptionsClass(const OpSpec& op, const Type& op_class,
- SourceWriter* writer) {
+ SourceWriter* writer) {
Type options_class = Type::Class("Options");
- Javadoc options_doc = Javadoc::Create(
- "Optional attributes for {@link " + op_class.canonical_name() + "}");
+ Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " +
+ op_class.canonical_name() + "}");
writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
for (const AttributeSpec& attr : op.optional_attributes()) {
Method setter = Method::Create(attr.var().name(), options_class);
@@ -339,24 +337,27 @@ void RenderOptionsClass(const OpSpec& op, const Type& op_class,
}
inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
- return Type::Class(endpoint.name(),
+ return Type::Class(
+ endpoint.name(),
base_package + "." + str_util::Lowercase(endpoint.package()));
}
void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
- const string& base_package, const string& output_dir, Env* env) {
- Type op_class(ClassOf(endpoint, base_package)
- .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
+ const string& base_package, const string& output_dir,
+ Env* env) {
+ Type op_class(
+ ClassOf(endpoint, base_package)
+ .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
Javadoc op_javadoc(endpoint.javadoc());
// op interfaces
RenderMode mode = DEFAULT;
if (op.outputs().size() == 1) {
const ArgumentSpec& output = op.outputs().front();
- Type operand_type(output.type().wildcard() ?
- Type::Class("Object") : output.type());
+ Type operand_type(output.type().wildcard() ? Type::Class("Object")
+ : output.type());
Type operand_inf(Type::Interface("Operand", "org.tensorflow")
- .add_parameter(operand_type));
+ .add_parameter(operand_type));
if (output.iterable()) {
mode = LIST_OPERAND;
op_class.add_supertype(Type::IterableOf(operand_inf));
@@ -368,10 +369,11 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
// op generic parameters
std::set<string> generics;
for (const ArgumentSpec& output : op.outputs()) {
- if (output.type().kind() == Type::GENERIC && !output.type().wildcard()
- && generics.find(output.type().name()) == generics.end()) {
+ if (output.type().kind() == Type::GENERIC && !output.type().wildcard() &&
+ generics.find(output.type().name()) == generics.end()) {
op_class.add_parameter(output.type());
- op_javadoc.add_param_tag("<" + output.type().name() + ">",
+ op_javadoc.add_param_tag(
+ "<" + output.type().name() + ">",
"data type for {@code " + output.var().name() + "()} output");
generics.insert(output.type().name());
}
@@ -384,9 +386,10 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
op_class.add_annotation(Annotation::Create("Deprecated"));
string explanation;
if (!op.endpoints().front().deprecated()) {
- explanation = "use {@link " +
- ClassOf(op.endpoints().front(), base_package).canonical_name()
- + "} instead";
+ explanation =
+ "use {@link " +
+ ClassOf(op.endpoints().front(), base_package).canonical_name() +
+ "} instead";
} else {
explanation = op.deprecation_explanation();
}
@@ -396,27 +399,27 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
// expose the op in the Ops Graph API only if it is visible
op_class.add_annotation(
Annotation::Create("Operator", "org.tensorflow.op.annotation")
- .attributes("group = \"" + endpoint.package() + "\""));
+ .attributes("group = \"" + endpoint.package() + "\""));
}
// create op class file
- const string op_dir_name = io::JoinPath(output_dir,
- str_util::StringReplace(op_class.package(), ".", "/", true));
+ const string op_dir_name = io::JoinPath(
+ output_dir, str_util::StringReplace(op_class.package(), ".", "/", true));
if (!env->FileExists(op_dir_name).ok()) {
TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name))
<< op_dir_name;
}
const string op_file_name = op_class.name() + ".java";
std::unique_ptr<tensorflow::WritableFile> op_file;
- TF_CHECK_OK(env->NewWritableFile(
- io::JoinPath(op_dir_name, op_file_name), &op_file)) << op_file_name;
+ TF_CHECK_OK(
+ env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file))
+ << op_file_name;
// render endpoint source code
SourceFileWriter writer(op_file.get());
std::list<Type> dependencies;
CollectOpDependencies(op, mode, &dependencies);
- writer.Write(kLicense)
- .EndLine()
- .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc);
+ writer.Write(kLicense).EndLine().BeginType(op_class, PUBLIC | FINAL,
+ &dependencies, &op_javadoc);
if (!op.optional_attributes().empty()) {
RenderOptionsClass(op, op_class, &writer);
}
@@ -448,7 +451,7 @@ bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) {
} // namespace
Status OpGenerator::Run(const OpList& op_list, const string& base_package,
- const string& output_dir) {
+ const string& output_dir) {
ApiDefMap api_map(op_list);
if (!api_dirs_.empty()) {
// Only load api files that correspond to the requested "op_list"
diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h
index cfe842070a..759d800ecf 100644
--- a/tensorflow/java/src/gen/cc/op_generator.h
+++ b/tensorflow/java/src/gen/cc/op_generator.h
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
-#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/java/src/gen/cc/op_specs.h"
@@ -37,14 +37,15 @@ namespace java {
class OpGenerator {
public:
explicit OpGenerator(const std::vector<string>& api_dirs,
- Env* env = Env::Default()) : api_dirs_(api_dirs), env_(env) {}
+ Env* env = Env::Default())
+ : api_dirs_(api_dirs), env_(env) {}
// Generates wrappers for the given list of 'ops'.
//
// Output files are generated in <output_dir>/<base_package>/<op_package>,
// where 'op_package' is derived from ops endpoints.
Status Run(const OpList& op_list, const string& base_package,
- const string& output_dir);
+ const string& output_dir);
private:
const std::vector<string> api_dirs_;
diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
index 56806cbb6d..8e5fba7e32 100644
--- a/tensorflow/java/src/gen/cc/source_writer.cc
+++ b/tensorflow/java/src/gen/cc/source_writer.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <string>
#include <algorithm>
#include <list>
+#include <string>
#include "tensorflow/java/src/gen/cc/source_writer.h"
@@ -123,7 +124,7 @@ SourceWriter& SourceWriter::EndBlock() {
}
SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers,
- const Javadoc* javadoc) {
+ const Javadoc* javadoc) {
GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
if (!method.constructor()) {
generic_namespace->Visit(method.return_type());
@@ -165,7 +166,8 @@ SourceWriter& SourceWriter::EndMethod() {
}
SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers,
- const std::list<Type>* extra_dependencies, const Javadoc* javadoc) {
+ const std::list<Type>* extra_dependencies,
+ const Javadoc* javadoc) {
if (!type.package().empty()) {
Append("package ").Append(type.package()).Append(";").EndLine();
}
@@ -186,7 +188,7 @@ SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers,
}
SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers,
- const Javadoc* javadoc) {
+ const Javadoc* javadoc) {
GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
generic_namespace->Visit(type);
EndLine();
@@ -226,7 +228,7 @@ SourceWriter& SourceWriter::EndType() {
}
SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers,
- const Javadoc* javadoc) {
+ const Javadoc* javadoc) {
// If present, write field javadoc only as one brief line
if (javadoc != nullptr && !javadoc->brief().empty()) {
Append("/** ").Append(javadoc->brief()).Append(" */").EndLine();
@@ -345,8 +347,8 @@ void SourceWriter::TypeVisitor::Visit(const Type& type) {
void SourceWriter::GenericNamespace::DoVisit(const Type& type) {
// ignore non-generic parameters, wildcards and generics already declared
- if (type.kind() == Type::GENERIC && !type.wildcard()
- && generic_names_.find(type.name()) == generic_names_.end()) {
+ if (type.kind() == Type::GENERIC && !type.wildcard() &&
+ generic_names_.find(type.name()) == generic_names_.end()) {
declared_types_.push_back(&type);
generic_names_.insert(type.name());
}
diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h
index 1f0febe9a3..de0113bd5b 100644
--- a/tensorflow/java/src/gen/cc/source_writer.h
+++ b/tensorflow/java/src/gen/cc/source_writer.h
@@ -93,7 +93,7 @@ class SourceWriter {
// This method appends a new opening brace to the current data and indent the
// next lines according to Google Java Style Guide. The block can optionally
// be preceded by an expression (e.g. Append("if(true)").BeginBlock();)
- SourceWriter& BeginBlock(const string& expr = "");
+ SourceWriter& BeginBlock(const string& expression = "");
// Ends the current block of source code.
//
@@ -108,7 +108,7 @@ class SourceWriter {
// in parameter to define the access scope of this method and, optionally,
// a Javadoc.
SourceWriter& BeginMethod(const Method& method, int modifiers,
- const Javadoc* javadoc = nullptr);
+ const Javadoc* javadoc = nullptr);
// Ends the current method.
//
@@ -125,9 +125,9 @@ class SourceWriter {
//
// If not null, all types found in the 'extra_dependencies' list will be
// imported before declaring the new type.
- SourceWriter& BeginType(const Type& clazz, int modifiers,
- const std::list<Type>* extra_dependencies = nullptr,
- const Javadoc* javadoc = nullptr);
+ SourceWriter& BeginType(const Type& type, int modifiers,
+ const std::list<Type>* extra_dependencies = nullptr,
+ const Javadoc* javadoc = nullptr);
// Begins to write a new inner type.
//
@@ -136,7 +136,7 @@ class SourceWriter {
// in parameter to define the accesses and the scope of this type and,
// optionally, a Javadoc.
SourceWriter& BeginInnerType(const Type& type, int modifiers,
- const Javadoc* javadoc = nullptr);
+ const Javadoc* javadoc = nullptr);
// Ends the current type.
//
@@ -150,7 +150,7 @@ class SourceWriter {
// or BeginInnerType()). Modifiers are also be passed in parameter to define
// the accesses and the scope of this field and, optionally, a Javadoc.
SourceWriter& WriteField(const Variable& field, int modifiers,
- const Javadoc* javadoc = nullptr);
+ const Javadoc* javadoc = nullptr);
protected:
virtual void DoAppend(const StringPiece& str) = 0;
diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc
index b9a5fee9be..fb8fc64dff 100644
--- a/tensorflow/java/src/gen/cc/source_writer_test.cc
+++ b/tensorflow/java/src/gen/cc/source_writer_test.cc
@@ -245,12 +245,17 @@ TEST(StreamTest, Types) {
SourceBufferWriter writer;
Type generic = Type::Generic("T").add_supertype(Type::Class("Number"));
- writer.AppendType(Type::Int()).Append(", ")
- .AppendType(Type::Class("String")).Append(", ")
- .AppendType(generic).Append(", ")
- .AppendType(Type::ListOf(generic)).Append(", ")
- .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ")
- .AppendType(Type::ListOf(Type::Wildcard()));
+ writer.AppendType(Type::Int())
+ .Append(", ")
+ .AppendType(Type::Class("String"))
+ .Append(", ")
+ .AppendType(generic)
+ .Append(", ")
+ .AppendType(Type::ListOf(generic))
+ .Append(", ")
+ .AppendType(Type::ListOf(Type::IterableOf(generic)))
+ .Append(", ")
+ .AppendType(Type::ListOf(Type::Wildcard()));
const char* expected =
"int, String, T, List<T>, List<Iterable<T>>, List<?>";
@@ -314,7 +319,7 @@ TEST(WriteType, AnnotatedAndDocumentedClass) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
Javadoc clazz_doc = Javadoc::Create("Javadoc test")
- .details("This is a\nmultiline description.");
+ .details("This is a\nmultiline description.");
clazz.add_annotation(Annotation::Create("Bean"));
clazz.add_annotation(Annotation::Create("SuppressWarnings")
.attributes("\"rawtypes\""));
@@ -380,10 +385,10 @@ TEST(WriteType, ParameterizedClassFields) {
Javadoc field3_doc = Javadoc::Create("This variable is documented");
writer.BeginType(clazz, PUBLIC)
- .WriteField(field1, STATIC | PUBLIC | FINAL)
- .WriteField(field2, PRIVATE)
- .WriteField(field3, PRIVATE, &field3_doc)
- .EndType();
+ .WriteField(field1, STATIC | PUBLIC | FINAL)
+ .WriteField(field2, PRIVATE)
+ .WriteField(field3, PRIVATE, &field3_doc)
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -402,9 +407,9 @@ TEST(WriteType, SimpleInnerClass) {
Type inner_class = Type::Class("InnerTest");
writer.BeginType(clazz, PUBLIC)
- .BeginInnerType(inner_class, PUBLIC)
- .EndType()
- .EndType();
+ .BeginInnerType(inner_class, PUBLIC)
+ .EndType()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -425,9 +430,9 @@ TEST(WriteType, StaticParameterizedInnerClass) {
inner_class.add_parameter(type_t);
writer.BeginType(clazz, PUBLIC)
- .BeginInnerType(inner_class, PUBLIC | STATIC)
- .EndType()
- .EndType();
+ .BeginInnerType(inner_class, PUBLIC | STATIC)
+ .EndType()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -445,8 +450,9 @@ TEST(WriteMethod, SimpleMethod) {
Method method = Method::Create("doNothing", Type::Void());
writer.BeginType(clazz, PUBLIC)
- .BeginMethod(method, PUBLIC).EndMethod()
- .EndType();
+ .BeginMethod(method, PUBLIC)
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -462,15 +468,17 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
Method method = Method::Create("doNothing", Type::Void());
- Javadoc method_doc = Javadoc::Create("Javadoc test")
- .details("This method has a\nmultiline description.");
+ Javadoc method_doc =
+ Javadoc::Create("Javadoc test")
+ .details("This method has a\nmultiline description.");
method.add_annotation(Annotation::Create("Override"));
method.add_annotation(Annotation::Create("SuppressWarnings")
.attributes("\"rawtypes\""));
writer.BeginType(clazz, PUBLIC)
- .BeginMethod(method, PUBLIC, &method_doc).EndMethod()
- .EndType();
+ .BeginMethod(method, PUBLIC, &method_doc)
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -497,20 +505,23 @@ TEST(WriteMethod, DocumentedMethodWithArguments) {
Method method = Method::Create("boolToInt", Type::Int());
method.add_argument(Variable::Create("b", Type::Boolean()));
method.add_argument(reverse);
- Javadoc method_doc = Javadoc::Create("Converts a boolean to an int")
- .details("This method will convert\na boolean to an int")
- .add_param_tag(reverse.name(), "if true, value is reversed")
- .add_tag("return", "int value for this boolean");
+ Javadoc method_doc =
+ Javadoc::Create("Converts a boolean to an int")
+ .details("This method will convert\na boolean to an int")
+ .add_param_tag(reverse.name(), "if true, value is reversed")
+ .add_tag("return", "int value for this boolean");
writer.BeginType(clazz, PUBLIC)
- .BeginMethod(method, PUBLIC, &method_doc)
- .Append("if (b && !reverse)")
- .BeginBlock()
- .Append("return 1;").EndLine()
- .EndBlock()
- .Append("return 0;").EndLine()
- .EndMethod()
- .EndType();
+ .BeginMethod(method, PUBLIC, &method_doc)
+ .Append("if (b && !reverse)")
+ .BeginBlock()
+ .Append("return 1;")
+ .EndLine()
+ .EndBlock()
+ .Append("return 0;")
+ .EndLine()
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -543,10 +554,11 @@ TEST(WriteMethod, ParameterizedMethod) {
Method method = Method::Create("doNothing", type_t);
writer.BeginType(clazz, PUBLIC)
- .BeginMethod(method, PUBLIC)
- .Append("return null;").EndLine()
- .EndMethod()
- .EndType();
+ .BeginMethod(method, PUBLIC)
+ .Append("return null;")
+ .EndLine()
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -567,10 +579,11 @@ TEST(WriteMethod, StaticParameterizedMethod) {
Method method = Method::Create("doNothing", type_t);
writer.BeginType(clazz, PUBLIC)
- .BeginMethod(method, PUBLIC | STATIC)
- .Append("return null;").EndLine()
- .EndMethod()
- .EndType();
+ .BeginMethod(method, PUBLIC | STATIC)
+ .Append("return null;")
+ .EndLine()
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7201e12c50..679ef93229 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -79,6 +79,7 @@ py_library(
":check_ops",
":client",
":client_testlib",
+ ":collective_ops",
":confusion_matrix",
":control_flow_ops",
":cudnn_rnn_ops_gen",
@@ -255,7 +256,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/py/numpy:headers",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -268,7 +269,7 @@ cc_library(
":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -292,7 +293,7 @@ cc_library(
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -315,7 +316,7 @@ cc_library(
":safe_ptr",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -337,7 +338,7 @@ cc_library(
"//tensorflow/core:script_ops_op_lib",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//third_party/py/numpy:headers",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -348,7 +349,7 @@ cc_library(
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -378,7 +379,7 @@ cc_library(
":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -389,7 +390,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:script_ops_op_lib",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -1436,6 +1437,14 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "collective_ops_gen",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:collective_ops_op_lib",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "control_flow_ops_gen",
visibility = [
"//learning/brain/python/ops:__pkg__",
@@ -1737,8 +1746,32 @@ py_test(
)
py_library(
+ name = "collective_ops",
+ srcs = ["ops/collective_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":collective_ops_gen",
+ ":framework_for_generated_wrappers",
+ ],
+)
+
+py_test(
+ name = "collective_ops_test",
+ size = "small",
+ srcs = ["ops/collective_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":collective_ops",
+ ":framework_for_generated_wrappers",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "control_flow_grad",
- srcs = ["ops/control_flow_grad.py"],
+ srcs =
+ ["ops/control_flow_grad.py"],
srcs_version = "PY2AND3",
deps = [
":control_flow_ops",
@@ -3404,7 +3437,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/py/numpy:headers",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
@@ -3475,6 +3508,7 @@ tf_py_wrap_cc(
":py_record_writer_lib",
":python_op_gen",
":tf_session_helper",
+ "//third_party/python_runtime:headers",
"//tensorflow/c:c_api",
"//tensorflow/c:checkpoint_reader",
"//tensorflow/c:python_api",
@@ -3497,7 +3531,6 @@ tf_py_wrap_cc(
"//tensorflow/core/profiler/internal:print_model_analysis",
"//tensorflow/tools/graph_transforms:transform_graph_lib",
"//tensorflow/python/eager:pywrap_tfe_lib",
- "//util/python:python_headers",
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps() +
tf_additional_verbs_deps() +
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index 1ddedfda4e..e99f0a203b 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -24,6 +24,7 @@ import zlib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -38,6 +39,13 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
+try:
+ import psutil # pylint: disable=g-import-not-at-top
+ psutil_import_succeeded = True
+except ImportError:
+ psutil_import_succeeded = False
+
+
class TextLineDatasetTest(test.TestCase):
def _lineText(self, f, l):
@@ -162,6 +170,34 @@ class TextLineDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
+ def testIteratorResourceCleanup(self):
+ filename = os.path.join(self.get_temp_dir(), "text.txt")
+ with open(filename, "wt") as f:
+ for i in range(3):
+ f.write("%d\n" % (i,))
+ with context.eager_mode():
+ first_iterator = iter(readers.TextLineDataset(filename))
+ self.assertEqual(b"0", next(first_iterator).numpy())
+ second_iterator = iter(readers.TextLineDataset(filename))
+ self.assertEqual(b"0", next(second_iterator).numpy())
+ # Eager kernel caching is based on op attributes, which includes the
+ # Dataset's output shape. Create a different kernel to test that they
+ # don't create resources with the same names.
+ different_kernel_iterator = iter(
+ readers.TextLineDataset(filename).repeat().batch(16))
+ self.assertEqual([16], next(different_kernel_iterator).shape)
+ # Remove our references to the Python Iterator objects, which (assuming no
+ # reference cycles) is enough to trigger DestroyResourceOp and close the
+ # partially-read files.
+ del first_iterator
+ del second_iterator
+ del different_kernel_iterator
+ if not psutil_import_succeeded:
+ self.skipTest(
+ "psutil is required to check that we've closed our files.")
+ open_files = psutil.Process().open_files()
+ self.assertNotIn(filename, [open_file.path for open_file in open_files])
+
class FixedLengthRecordReaderTest(test.TestCase):
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index fd164277b6..b6dba4e3ca 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -471,9 +471,7 @@ class EagerIterator(object):
sparse.as_dense_types(self._output_types, self._output_classes))
self._flat_output_shapes = nest.flatten(
sparse.as_dense_shapes(self._output_shapes, self._output_classes))
- self._resource = gen_dataset_ops.iterator(
- shared_name="",
- container=_generate_shared_name("eageriterator"),
+ self._resource = gen_dataset_ops.anonymous_iterator(
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
gen_dataset_ops.make_iterator(ds_variant, self._resource)
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 183994ddaa..09062abd74 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -572,6 +572,7 @@ py_test(
":source_utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
@@ -1003,6 +1004,7 @@ cuda_py_test(
"no_oss", # Test flaky due to port collisions.
"no_windows",
"noasan", # Times out due to size of test (b/73731462).
+ "optonly", # Test flaky (b/80130873)
"oss_serial",
],
)
diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py
index 9170046948..a7be20948d 100644
--- a/tensorflow/python/debug/lib/grpc_debug_test_server.py
+++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py
@@ -245,7 +245,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
self._origin_id_to_strings = []
self._graph_tracebacks = []
self._graph_versions = []
- self._source_files = None
+ self._source_files = []
def _initialize_toggle_watch_state(self, toggle_watches):
self._toggle_watches = toggle_watches
@@ -274,7 +274,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
self._origin_id_to_strings = []
self._graph_tracebacks = []
self._graph_versions = []
- self._source_files = None
+ self._source_files = []
def SendTracebacks(self, request, context):
self._call_types.append(request.call_type)
@@ -286,7 +286,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
return debug_service_pb2.EventReply()
def SendSourceFiles(self, request, context):
- self._source_files = request
+ self._source_files.append(request)
return debug_service_pb2.EventReply()
def query_op_traceback(self, op_name):
@@ -351,9 +351,10 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
if not self._source_files:
raise ValueError(
"This debug server has not received any source file contents yet.")
- for source_file_proto in self._source_files.source_files:
- if source_file_proto.file_path == file_path:
- return source_file_proto.lines[lineno - 1]
+ for source_files in self._source_files:
+ for source_file_proto in source_files.source_files:
+ if source_file_proto.file_path == file_path:
+ return source_file_proto.lines[lineno - 1]
raise ValueError(
"Source file at path %s has not been received by the debug server",
file_path)
diff --git a/tensorflow/python/debug/lib/source_remote.py b/tensorflow/python/debug/lib/source_remote.py
index 4b6b2b995e..4afae41bc9 100644
--- a/tensorflow/python/debug/lib/source_remote.py
+++ b/tensorflow/python/debug/lib/source_remote.py
@@ -28,6 +28,7 @@ from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging
from tensorflow.python.profiler import tfprof_logger
@@ -95,6 +96,11 @@ def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
return non_tf_files
+def grpc_message_length_bytes():
+ """Maximum gRPC message length in bytes."""
+ return 4 * 1024 * 1024
+
+
def _send_call_tracebacks(destinations,
origin_stack,
is_eager_execution=False,
@@ -155,17 +161,28 @@ def _send_call_tracebacks(destinations,
source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
[call_traceback.origin_stack], call_traceback.origin_id_to_string))
- debugged_source_files = debug_pb2.DebuggedSourceFiles()
+ debugged_source_files = []
for file_path in source_file_paths:
+ source_files = debug_pb2.DebuggedSourceFiles()
_load_debugged_source_file(
- file_path, debugged_source_files.source_files.add())
+ file_path, source_files.source_files.add())
+ debugged_source_files.append(source_files)
for destination in destinations:
channel = grpc.insecure_channel(destination)
stub = debug_service_pb2_grpc.EventListenerStub(channel)
stub.SendTracebacks(call_traceback)
if send_source:
- stub.SendSourceFiles(debugged_source_files)
+ for path, source_files in zip(
+ source_file_paths, debugged_source_files):
+ if source_files.ByteSize() < grpc_message_length_bytes():
+ stub.SendSourceFiles(source_files)
+ else:
+ tf_logging.warn(
+ "The content of the source file at %s is not sent to "
+ "gRPC debug server %s, because the message size exceeds "
+ "gRPC message length limit (%d bytes)." % (
+ path, destination, grpc_message_length_bytes()))
def send_graph_tracebacks(destinations,
diff --git a/tensorflow/python/debug/lib/source_remote_test.py b/tensorflow/python/debug/lib/source_remote_test.py
index 27bafa45e1..29add425e9 100644
--- a/tensorflow/python/debug/lib/source_remote_test.py
+++ b/tensorflow/python/debug/lib/source_remote_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -155,6 +156,51 @@ class SendTracebacksTest(test_util.TensorFlowTestCase):
self.assertEqual(["dummy_run_key"], server.query_call_keys())
self.assertEqual([sess.graph.version], server.query_graph_versions())
+ def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
+ """In case source file size exceeds the grpc message length limit.
+
+ it ought not to have been sent to the server.
+ """
+ this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"
+
+ # Patch the method to simulate a very small message length limit.
+ with test.mock.patch.object(
+ source_remote, "grpc_message_length_bytes", return_value=2):
+ with session.Session() as sess:
+ a = variables.Variable(21.0, name="two/a")
+ a_lineno = line_number_above()
+ b = variables.Variable(2.0, name="two/b")
+ b_lineno = line_number_above()
+ x = math_ops.add(a, b, name="two/x")
+ x_lineno = line_number_above()
+
+ send_traceback = traceback.extract_stack()
+ send_lineno = line_number_above()
+ source_remote.send_graph_tracebacks(
+ [self._server_address, self._server_address_2],
+ "dummy_run_key", send_traceback, sess.graph)
+
+ servers = [self._server, self._server_2]
+ for server in servers:
+ # Even though the source file content is not sent, the traceback
+ # should have been sent.
+ tb = server.query_op_traceback("two/a")
+ self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
+ tb = server.query_op_traceback("two/b")
+ self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
+ tb = server.query_op_traceback("two/x")
+ self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
+
+ self.assertIn(
+ (self._curr_file_path, send_lineno, this_func_name),
+ server.query_origin_stack()[-1])
+
+ tf_trace_file_path = (
+ self._findFirstTraceInsideTensorFlowPyLibrary(x.op))
+ # Verify that the source content is not sent to the server.
+ with self.assertRaises(ValueError):
+ self._server.query_source_file_line(tf_trace_file_path, 0)
+
def testSendEagerTracebacksToSingleDebugServer(self):
this_func_name = "testSendEagerTracebacksToSingleDebugServer"
send_traceback = traceback.extract_stack()
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 5530193d4e..dee86966f1 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -26,12 +26,13 @@ cc_library(
"//tensorflow/c/eager:tape",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/python:cpp_python_util",
"//tensorflow/python:ndarray_tensor",
"//tensorflow/python:ndarray_tensor_bridge",
"//tensorflow/python:numpy_lib",
"//tensorflow/python:py_seq_tensor",
"//tensorflow/python:safe_ptr",
- "//util/python:python_headers",
+ "//third_party/python_runtime:headers",
],
)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index dcfd03b458..b2e6c60021 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import functools
import operator
-import threading
import six
@@ -94,8 +93,8 @@ class _MockOp(object):
)
-def _magic_gradient_function(op_name, attr_tuple, num_inputs,
- inputs, outputs, out_grads):
+def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
+ out_grads):
"""Calls the gradient function of the op.
Args:
@@ -117,8 +116,7 @@ def _magic_gradient_function(op_name, attr_tuple, num_inputs,
return grad_fn(mock_op, *out_grads)
-_gradient_functions = {}
-_gradient_functions_lock = threading.Lock()
+pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
_tracing = False
@@ -142,22 +140,6 @@ _grad_fn_accepts_none_for_indices = {
}
-def _get_backward_fn(op_name, attrs, num_inputs, op_inputs, op_outputs):
-
- def grad_fn(*orig_outputs):
- result = _magic_gradient_function(op_name, attrs, num_inputs,
- op_inputs, op_outputs, orig_outputs)
- if _tracing:
- print("Gradient for", op_name, "inputs", op_inputs, "output_grads",
- orig_outputs, "gradients", result)
- return nest.flatten(result)
-
- return grad_fn
-
-
-pywrap_tensorflow.TFE_Py_RegisterBackwardFunctionGetter(_get_backward_fn)
-
-
def _record_gradient(op_name, inputs, attrs, results, name):
return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
results, name)
@@ -225,16 +207,14 @@ def implicit_val_and_grad(f):
f.__name__))
finally:
tape.pop_tape(this_tape)
- # Sorting variables by id, which is monotonically increasing in construction
- # order. This ensures unique order across executions.
- # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc.
- variables = list(sorted(this_tape.watched_variables(),
- key=lambda v: v.handle._id)) # pylint: disable=protected-access
- sources = [x.handle for x in variables]
-
- if not sources:
+ # Note: variables are returned in construction order. This ensures unique
+ # order across executions.
+ variables = this_tape.watched_variables()
+ if not variables:
raise ValueError("No trainable variables were accessed while the "
"function was being computed.")
+
+ sources = [v.handle for v in variables]
grad = imperative_grad.imperative_grad(_default_vspace,
this_tape,
nest.flatten(end_node),
@@ -819,11 +799,8 @@ class GradientTape(object):
self._push_tape()
def watched_variables(self):
- # Sorting variables by id, which is monotonically increasing in construction
- # order. This ensures unique order across executions.
- # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc.
- return list(sorted(self._tape.watched_variables(),
- key=lambda v: v.handle._id)) # pylint: disable=protected-access
+ """Returns variables watched by this tape in order of construction."""
+ return self._tape.watched_variables()
def gradient(self, target, sources, output_gradients=None):
"""Computes the gradient using operations recorded in context of this tape.
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 120b298171..23d87fb394 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -777,7 +777,7 @@ def defun(func=None, compiled=False):
def h():
return f(x, y)
- assert h().numpy() == f(x, y)
+ assert (h().numpy() == f(x, y).numpy()).all()
# `defun` automatically lifts variables out of the graphs it creates,
# allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
@@ -785,6 +785,7 @@ def defun(func=None, compiled=False):
class MyModel(tf.keras.Model):
def __init__(self, keep_probability=0.2):
+ super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.keep_probability = keep_probability
@@ -804,7 +805,7 @@ def defun(func=None, compiled=False):
# `defun`-compiled functions are differentiable.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
with tf.GradientTape() as tape:
- outputs = model(inputs)
+ outputs = model(x)
gradient = tape.gradient(outputs, model.trainable_variables)
optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
model.trainable_variables))
@@ -840,6 +841,8 @@ def defun(func=None, compiled=False):
import tensorflow as tf
import numpy as np
+ tf.enable_eager_execution()
+
matrix = tf.eye(5)
# `matrix` is assumed to be a Tensor
def add_noise():
@@ -862,6 +865,8 @@ def defun(func=None, compiled=False):
```python
import tensorflow as tf
+ tf.enable_eager_execution()
+
@tf.contrib.eager.defun
def lossy_matmul(W, x, training=True):
outputs = tf.matmul(W, x)
@@ -869,6 +874,9 @@ def defun(func=None, compiled=False):
outputs = tf.nn.dropout(outputs, keep_probability=0.2)
return outputs
+ W = tf.random_normal((3, 5))
+ x = tf.random_normal((5, 1))
+
# Executes a graph that applies dropout.
lossy_outputs = lossy_matmul(W, x, training=True)
@@ -919,14 +927,14 @@ def defun(func=None, compiled=False):
# `fn` is a Python function, so x is created, initialized, and destroyed upon
# every invocation
- assert(fn().numpy() == fn().numpy() == 1.0)
+ assert fn().numpy() == fn().numpy() == 1.0
compiled = tf.contrib.eager.defun(fn)
# Compiling `fn` with `defun` hoists all variables outside of the generated
# graph, so initialization happens exactly once.
- assert(compiled().numpy() == 1.0)
- assert(compiled().numpy() == 2.0)
+ assert compiled().numpy() == 1.0
+ assert compiled().numpy() == 2.0
```
Finally, because each input signature is bound to a unique graph, if your
@@ -1207,6 +1215,9 @@ class AutomaticControlDependencies(object):
# test that it works. Support while loops. Support init_scope escaping from
# this.
for op in new_operations:
+ # TODO(apassos) make this code safely support while loops.
+ if isinstance(op._control_flow_context, control_flow_ops.WhileContext): # pylint: disable=protected-access
+ continue
control_inputs = set()
# Ensure stateful ops run
if (op.type not in self._graph._registered_ops # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index a62af4a06c..ea604647fa 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -27,8 +27,15 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
+// forward declare
+struct EagerTensor;
+
namespace {
+// An instance of _EagerTensorProfiler that will receive callbacks about
+// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
+PyObject* eager_tensor_profiler = nullptr;
+
TFE_Context* GetContext(PyObject* ctx) {
TFE_Context* context =
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr));
@@ -253,8 +260,45 @@ typedef struct EagerTensor {
// to use a TF_Status object. However note that accesses to `status` are not
// thread-safe.
TF_Status* status;
+
+ PyObject* weakreflist; /* List of weak references */
} EagerTensor;
+namespace {
+
+// Returns true on success - successfully invoked or no profiler registered.
+// Returns false if some error occurred.
+bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
+ if (eager_tensor_profiler != nullptr) {
+#if PY_MAJOR_VERSION < 3
+ PyObject* created_method_name = PyString_InternFromString("created");
+#else
+ PyObject* created_method_name = PyUnicode_InternFromString("created");
+#endif
+ if (created_method_name == nullptr) {
+ return false;
+ }
+ PyObject* result = PyObject_CallMethodObjArgs(
+ eager_tensor_profiler, created_method_name, created_tensor, NULL);
+ if (result == nullptr) {
+ LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
+ // While we can potentially continue because the error is related to
+ // profiling, we choose to return an error because:
+ // - If profiling is used, the user likely wants to stop execution on
+ // profiling errors.
+ // - Error in profiling code might have left some state in an invalid
+ // form that can lead to an error later on. Better to fail fast.
+ Py_DECREF(created_method_name);
+ return false;
+ }
+ Py_DECREF(created_method_name);
+ Py_DECREF(result);
+ }
+ return true;
+}
+
+} // namespace
+
// tp_init for EagerTensor.
int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
self->id = get_uid();
@@ -266,6 +310,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
Py_INCREF(Py_None);
self->tensor_shape = Py_None;
self->status = TF_NewStatus();
+ self->weakreflist = nullptr;
PyObject* value;
PyObject* context = nullptr;
PyObject* device = nullptr;
@@ -344,11 +389,22 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
if (handle == nullptr) return -1;
}
self->handle = handle.release();
+
+ if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
+ return -1;
+ }
+
return 0;
}
// tp_dealloc for EagerTensor.
void EagerTensor_dealloc(EagerTensor* self) {
+ // Clear weak references to self.
+ // Needs to happen before any actual destruction.
+ if (self->weakreflist != nullptr) {
+ PyObject_ClearWeakRefs((PyObject*)self);
+ }
+
TF_DeleteStatus(self->status);
Py_DECREF(self->handle_data);
Py_DECREF(self->keras_mask);
@@ -574,43 +630,43 @@ static PyTypeObject _EagerTensorType = {
// clang-format off
PyVarObject_HEAD_INIT(nullptr, 0)
// clang-format on
- "EagerTensor", /* tp_name */
- sizeof(EagerTensor), /* tp_basicsize */
- 0, /* tp_itemsize */
- (destructor)EagerTensor_dealloc, /* tp_dealloc */
- nullptr, /* tp_print */
- nullptr, /* tp_getattr */
- nullptr, /* tp_setattr */
- nullptr, /* tp_compare */
- nullptr, /* tp_repr */
- nullptr, /* tp_as_number */
- nullptr, /* tp_as_sequence */
- nullptr, /* tp_as_mapping */
- nullptr, /* tp_hash */
- nullptr, /* tp_call */
- nullptr, /* tp_str */
- nullptr, /* tp_getattro */
- nullptr, /* tp_setattro */
- nullptr, /* tp_as_buffer */
- Py_TPFLAGS_DEFAULT, /* tp_flags */
- nullptr, /* tp_doc */
- nullptr, /* tp_traverse */
- nullptr, /* tp_clear */
- nullptr, /* tp_richcompare */
- 0, /* tp_weaklistoffset */
- nullptr, /* tp_iter */
- nullptr, /* tp_iternext */
- EagerTensor_methods, /* tp_methods */
- nullptr, /* tp_members */
- EagerTensor_getseters, /* tp_getset */
- nullptr, /* tp_base */
- nullptr, /* tp_dict */
- nullptr, /* tp_descr_get */
- nullptr, /* tp_descr_set */
- 0, /* tp_dictoffset */
- (initproc)EagerTensor_init, /* tp_init */
- nullptr, /* tp_alloc */
- nullptr, /* tp_new */
+ "EagerTensor", /* tp_name */
+ sizeof(EagerTensor), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ (destructor)EagerTensor_dealloc, /* tp_dealloc */
+ nullptr, /* tp_print */
+ nullptr, /* tp_getattr */
+ nullptr, /* tp_setattr */
+ nullptr, /* tp_compare */
+ nullptr, /* tp_repr */
+ nullptr, /* tp_as_number */
+ nullptr, /* tp_as_sequence */
+ nullptr, /* tp_as_mapping */
+ nullptr, /* tp_hash */
+ nullptr, /* tp_call */
+ nullptr, /* tp_str */
+ nullptr, /* tp_getattro */
+ nullptr, /* tp_setattro */
+ nullptr, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT, /* tp_flags */
+ nullptr, /* tp_doc */
+ nullptr, /* tp_traverse */
+ nullptr, /* tp_clear */
+ nullptr, /* tp_richcompare */
+ offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
+ nullptr, /* tp_iter */
+ nullptr, /* tp_iternext */
+ EagerTensor_methods, /* tp_methods */
+ nullptr, /* tp_members */
+ EagerTensor_getseters, /* tp_getset */
+ nullptr, /* tp_base */
+ nullptr, /* tp_dict */
+ nullptr, /* tp_descr_get */
+ nullptr, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ (initproc)EagerTensor_init, /* tp_init */
+ nullptr, /* tp_alloc */
+ nullptr, /* tp_new */
};
#endif
@@ -641,6 +697,11 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
t->tensor_shape = Py_None;
t->handle = handle;
t->status = TF_NewStatus();
+ t->weakreflist = nullptr;
+
+ if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
+ return nullptr;
+ }
}
return reinterpret_cast<PyObject*>(t);
}
@@ -720,6 +781,18 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
return reinterpret_cast<PyObject*>(EagerTensorType);
}
+PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
+ Py_XDECREF(eager_tensor_profiler);
+
+ if (profiler == Py_None) {
+ eager_tensor_profiler = nullptr;
+ } else {
+ eager_tensor_profiler = profiler;
+ Py_INCREF(eager_tensor_profiler);
+ }
+ Py_RETURN_NONE;
+}
+
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
PyErr_SetString(PyExc_TypeError,
@@ -792,3 +865,37 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
return EagerTensorFromHandle(handle);
}
+
+PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
+ if (!EagerTensor_CheckExact(tensor)) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ tensorflow::strings::StrCat("Expected an EagerTensors but got type \"",
+ Py_TYPE(tensor)->tp_name, "\"")
+ .c_str());
+ return nullptr;
+ }
+ TFE_TensorHandle* handle = EagerTensor_Handle(tensor);
+
+ auto status = tensorflow::make_safe(TF_NewStatus());
+ TFE_TensorDebugInfo* debug_info =
+ TFE_TensorHandleTensorDebugInfo(handle, status.get());
+ if (TF_GetCode(status.get()) != TF_OK) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
+ TF_Message(status.get()))
+ .c_str());
+ return nullptr;
+ }
+
+ int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
+ PyObject* shape = PyTuple_New(rank);
+ for (int i = 0; i < rank; ++i) {
+ tensorflow::int64 dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
+ PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
+ }
+ TFE_DeleteTensorDebugInfo(debug_info);
+
+ return shape;
+}
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 9bc8b9bc72..a916a75f00 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -16,10 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
+#include <Python.h>
+
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include <Python.h>
typedef tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>
TFE_InputTensorHandles;
@@ -66,14 +67,15 @@ PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
// This function is not thread-safe.
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
-// Registers e as the backward_function_getter.
-// The registered function creates a backward function (a function that can
-// return the gradient of the inputs an op given the gradient of it's outputs).
-// The registered function will be passed the following arguments:
-// op_name, attrs, num_inputs, op_inputs, op_outputs
+// Registers e as the gradient_function.
+// The registered function takes
+// (op_name, attrs, num_inputs, inputs, outputs, output_gradients) and returns
+// the input gradients. This function will not correctly be able to generate
+// gradients for functional ops - the gradients for those ops are calculated
+// through a different codepath (see function.py for additional information).
//
// This function is not thread-safe.
-PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e);
+PyObject* TFE_Py_RegisterGradientFunction(PyObject* e);
// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
@@ -113,6 +115,15 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
// newly created type, or nullptr on error.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
+// Sets `profiler` as the current profiler to receive callbacks about events
+// on eager tensors. Currently, the only reported event is creation.
+// `profiler` is expected to have a `created(self, eager_tensor)` method that
+// takes the created tensor as its single argument.
+// Previous profiler, if any, is unset and will not receive any more
+// callbacks.
+// To unset the profiler, pass Py_None as the value of `profiler`.
+PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
+
// Creates a new tape and adds it to the active set. `persistent` must be a
// PyBool_Type, i.e either Py_True or Py_False
PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
@@ -186,7 +197,8 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
PyObject* attrs, PyObject* results,
PyObject* name);
-// Returns the set of variables watched by the given tape.
+// Returns all variables watched by the given tape in the order those variables
+// were created.
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
// Returns an EagerTensor of dimension [len(`tensors`)] containing
@@ -201,4 +213,8 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
// tensors in `tensors`.
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
+// Returns the shape of this tensor's on-device representation.
+// The shape is represented as a Python tuple of integers.
+PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);
+
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 62deb41e9b..52b90504f3 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
+#include "tensorflow/python/util/util.h"
using tensorflow::string;
using tensorflow::strings::Printf;
@@ -45,12 +46,14 @@ struct InputInfo {
bool is_list = false;
};
+// Takes in output gradients, returns input gradients.
+typedef std::function<PyObject*(PyObject*)> PyBackwardFunction;
+
using AttrToInputsMap =
tensorflow::gtl::FlatMap<string,
tensorflow::gtl::InlinedVector<InputInfo, 4>>;
-tensorflow::mutex all_attr_to_input_maps_lock(
- tensorflow::LINKER_INITIALIZED);
+tensorflow::mutex all_attr_to_input_maps_lock(tensorflow::LINKER_INITIALIZED);
tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
static auto* all_attr_to_input_maps =
new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
@@ -641,8 +644,8 @@ PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr;
// Python subclass of Exception that is created to signal fallback.
PyObject* fallback_exception_class = nullptr;
-// Python function that returns a backward_function.
-PyObject* backward_function_getter = nullptr;
+// Python function that returns input gradients given output gradients.
+PyObject* gradient_function = nullptr;
PyTypeObject* resource_variable_type = nullptr;
@@ -735,26 +738,26 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
}
}
-PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e) {
- if (backward_function_getter != nullptr) {
- Py_DECREF(backward_function_getter);
+PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
+ if (gradient_function != nullptr) {
+ Py_DECREF(gradient_function);
}
if (!PyCallable_Check(e)) {
- backward_function_getter = nullptr;
+ gradient_function = nullptr;
PyErr_SetString(PyExc_TypeError,
"TFE_Py_RegisterBackwardFunctionGetter: "
"Registered object should be function.");
return nullptr;
} else {
Py_INCREF(e);
- backward_function_getter = e;
+ gradient_function = e;
Py_RETURN_NONE;
}
}
void RaiseFallbackException(const char* message) {
if (fallback_exception_class != nullptr) {
- PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message));
+ PyErr_SetString(fallback_exception_class, message);
return;
}
@@ -772,8 +775,9 @@ int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
if (exception == nullptr) {
tensorflow::mutex_lock l(exception_class_mutex);
if (exception_class != nullptr) {
- PyErr_SetObject(exception_class,
- Py_BuildValue("si", msg, TF_GetCode(status)));
+ tensorflow::Safe_PyObjectPtr val(
+ Py_BuildValue("si", msg, TF_GetCode(status)));
+ PyErr_SetObject(exception_class, val.get());
return -1;
} else {
exception = PyExc_RuntimeError;
@@ -791,7 +795,8 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
if (exception == nullptr) {
tensorflow::mutex_lock l(exception_class_mutex);
if (exception_class != nullptr) {
- PyErr_SetObject(exception_class, Py_BuildValue("si", msg, status.code()));
+ tensorflow::Safe_PyObjectPtr val(Py_BuildValue("si", msg, status.code()));
+ PyErr_SetObject(exception_class, val.get());
return -1;
} else {
exception = PyExc_RuntimeError;
@@ -868,11 +873,28 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
return static_cast<tensorflow::DataType>(id);
}
+static tensorflow::int64 FastHandleId(PyObject* variable) {
+ PyObject* handle = PyObject_GetAttrString(variable, "handle");
+ if (handle == nullptr) {
+ return -1;
+ }
+ tensorflow::int64 id = FastTensorId(handle);
+ Py_DECREF(handle);
+ return id;
+}
+
+struct CompareByHandleId {
+ bool operator()(PyObject* lhs, PyObject* rhs) {
+ return FastHandleId(lhs) < FastHandleId(rhs);
+ }
+};
+
class GradientTape
- : public tensorflow::eager::GradientTape<PyObject, PyObject> {
+ : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
public:
explicit GradientTape(bool persistent)
- : tensorflow::eager::GradientTape<PyObject, PyObject>(persistent) {}
+ : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
+ persistent) {}
virtual ~GradientTape() {
for (PyObject* v : watched_variables_) {
@@ -898,12 +920,12 @@ class GradientTape
}
}
- const std::unordered_set<PyObject*> WatchedVariables() {
+ const std::set<PyObject*, CompareByHandleId> WatchedVariables() {
return watched_variables_;
}
private:
- std::unordered_set<PyObject*> watched_variables_;
+ std::set<PyObject*, CompareByHandleId> watched_variables_;
};
typedef struct {
@@ -1195,11 +1217,13 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
}
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
- const std::unordered_set<PyObject*>& watched_variables =
+ const auto& watched_variables =
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
- PyObject* result = PySet_New(nullptr);
+ PyObject* result = PyTuple_New(watched_variables.size());
+ Py_ssize_t pos = 0;
for (PyObject* variable : watched_variables) {
- PySet_Add(result, variable);
+ PyTuple_SET_ITEM(result, pos++, variable);
+ Py_INCREF(variable);
}
return result;
}
@@ -1225,11 +1249,13 @@ void TapeSetRecordOperation(
PyObject* op_type, PyObject* output_tensors,
const std::vector<tensorflow::int64>& input_ids,
const std::vector<tensorflow::DataType>& input_dtypes,
- PyObject* backward_function) {
+ const std::function<PyBackwardFunction*()>& backward_function_getter,
+ const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
int len = PySequence_Size(output_tensors);
+ if (PyErr_Occurred()) return;
output_info.reserve(len);
for (int i = 0; i < len; ++i) {
output_info.push_back(
@@ -1258,10 +1284,10 @@ void TapeSetRecordOperation(
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- Py_INCREF(backward_function);
- tape->tape->RecordOperation(
- op_type_str, output_info, input_ids, input_dtypes, backward_function,
- [backward_function]() { Py_DECREF(backward_function); });
+ auto* function = backward_function_getter();
+ tape->tape->RecordOperation(op_type_str, output_info, input_ids,
+ input_dtypes, function,
+ backward_function_killer);
}
}
} // namespace
@@ -1278,8 +1304,21 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
std::vector<tensorflow::DataType> input_dtypes =
MakeTensorDtypeList(input_tensors);
if (PyErr_Occurred()) return;
- TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes,
- backward_function);
+
+ TapeSetRecordOperation(
+ op_type, output_tensors, input_ids, input_dtypes,
+ [backward_function]() {
+ Py_INCREF(backward_function);
+ PyBackwardFunction* function =
+ new PyBackwardFunction([backward_function](PyObject* out_grads) {
+ return PyObject_CallObject(backward_function, out_grads);
+ });
+ return function;
+ },
+ [backward_function](PyBackwardFunction* py_backward_function) {
+ Py_DECREF(backward_function);
+ delete py_backward_function;
+ });
}
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
@@ -1288,7 +1327,8 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
}
}
-class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
+class PyVSpace
+ : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
public:
explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
@@ -1381,7 +1421,7 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
}
tensorflow::Status CallBackwardFunction(
- PyObject* backward_function,
+ PyBackwardFunction* backward_function,
tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
std::vector<PyObject*>* result) const final {
PyObject* grads = PyTuple_New(output_gradients.size());
@@ -1394,8 +1434,7 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
reinterpret_cast<PyObject*>(output_gradients[i]));
}
}
- PyObject* py_result = PyEval_CallObject(
- reinterpret_cast<PyObject*>(backward_function), grads);
+ PyObject* py_result = (*backward_function)(grads);
Py_DECREF(grads);
if (py_result == nullptr) {
return tensorflow::errors::Internal("gradient function threw exceptions");
@@ -1424,10 +1463,6 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
return tensorflow::Status::OK();
}
- void ReleaseBackwardFunction(PyObject* backward_function) const final {
- Py_DECREF(backward_function);
- }
-
void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
private:
@@ -1586,12 +1621,12 @@ bool CheckInputsOk(PyObject* seq, int start_index,
for (Py_ssize_t j = 0; j < PySequence_Fast_GET_SIZE(item); j++) {
PyObject* inner_item = PySequence_Fast_GET_ITEM(item, j);
if (!CheckOneInput(inner_item)) {
- VLOG(1)
- << "Falling back to slow path for Op \"" << op_def.name()
- << "\", Input \"" << op_def.input_arg(i).name() << "\", Index "
- << j
- << " since we expected an EagerTensor/ResourceVariable, but got "
- << inner_item->ob_type->tp_name;
+ VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
+ << "\", Input \"" << op_def.input_arg(i).name()
+ << "\", Index " << j
+ << " since we expected an EagerTensor/ResourceVariable, "
+ "but got "
+ << inner_item->ob_type->tp_name;
return false;
}
}
@@ -1798,18 +1833,41 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
}
PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
- PyObject* callback_args =
- Py_BuildValue("OOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs);
-
- PyObject* backward_function =
- PyObject_CallObject(backward_function_getter, callback_args);
- Py_DECREF(callback_args);
- if (backward_function == nullptr) return nullptr;
- TapeSetRecordOperation(op_name, results, input_ids, input_dtypes,
- backward_function);
-
- Py_DECREF(backward_function);
+ TapeSetRecordOperation(
+ op_name, results, input_ids, input_dtypes,
+ [op_name, attrs, num_inputs, op_inputs, op_outputs]() {
+ Py_INCREF(op_name);
+ Py_INCREF(attrs);
+ Py_INCREF(num_inputs);
+ Py_INCREF(op_inputs);
+ Py_INCREF(op_outputs);
+ PyBackwardFunction* function =
+ new PyBackwardFunction([op_name, attrs, num_inputs, op_inputs,
+ op_outputs](PyObject* output_grads) {
+ tensorflow::Safe_PyObjectPtr callback_args(
+ Py_BuildValue("OOOOOO", op_name, attrs, num_inputs, op_inputs,
+ op_outputs, output_grads));
+
+ tensorflow::Safe_PyObjectPtr result(
+ PyObject_CallObject(gradient_function, callback_args.get()));
+
+ if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
+
+ return tensorflow::swig::Flatten(result.get());
+ });
+ return function;
+ },
+ [op_name, attrs, num_inputs, op_inputs,
+ op_outputs](PyBackwardFunction* backward_function) {
+ Py_DECREF(op_name);
+ Py_DECREF(attrs);
+ Py_DECREF(num_inputs);
+ Py_DECREF(op_inputs);
+ Py_DECREF(op_outputs);
+
+ delete backward_function;
+ });
Py_RETURN_NONE;
}
@@ -1880,8 +1938,8 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
// Supports only 2 cases at the moment:
// i) input is an EagerTensor
-// ii) input is a ResourceVariable - in this case, the is_variable param is set
-// to true.
+// ii) input is a ResourceVariable - in this case, the is_variable param is
+// set to true.
//
// NOTE: dtype_hint_getter must *always* return a PyObject that can be
// decref'd. So if no hint is found, Py_RETURN_NONE (which correctly
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index b044b30231..626a4eb1ee 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -292,6 +292,11 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
def testUnicode(self):
self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf")
+ def testFloatTensor(self):
+ self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype)
+ self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype)
+ self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype)
+
def testSliceDimOutOfRange(self):
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
t2 = _create_tensor([1, 2], dtype=dtypes.int32)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index f09c90bec8..331ee7490e 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -212,8 +212,8 @@ class Estimator(object):
else:
self._session_config = self._config.session_config
- self._device_fn = self._config.device_fn or \
- _get_replica_device_setter(self._config)
+ self._device_fn = (
+ self._config.device_fn or _get_replica_device_setter(self._config))
if model_fn is None:
raise ValueError('model_fn must be provided to Estimator.')
@@ -302,7 +302,7 @@ class Estimator(object):
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$get_started/premade_estimators#create_input_functions} for more
+ See @{$premade_estimators#create_input_functions} for more
information. The function should construct and return one of
the following:
@@ -398,7 +398,7 @@ class Estimator(object):
Args:
input_fn: A function that constructs the input data for evaluation.
- See @{$get_started/premade_estimators#create_input_functions} for more
+ See @{$premade_estimators#create_input_functions} for more
information. The function should construct and return one of
the following:
@@ -477,7 +477,7 @@ class Estimator(object):
input_fn: A function that constructs the features. Prediction continues
until `input_fn` raises an end-of-input exception (`OutOfRangeError` or
`StopIteration`).
- See @{$get_started/premade_estimators#create_input_functions} for more
+ See @{$premade_estimators#create_input_functions} for more
information. The function should construct and return one of
the following:
@@ -564,7 +564,8 @@ class Estimator(object):
allowed_overrides = set([
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_validate_features_in_predict_input'
+ '_tf_api_names', '_validate_features_in_predict_input',
+ '_call_model_fn', '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
if not m.startswith('__')])
@@ -828,10 +829,14 @@ class Estimator(object):
gfile.Rename(temp_export_dir, export_dir)
return export_dir
- def _add_meta_graph_for_mode(
- self, builder, input_receiver_fn_map, checkpoint_path,
- strip_default_attrs, save_variables=True,
- mode=model_fn_lib.ModeKeys.PREDICT):
+ def _add_meta_graph_for_mode(self,
+ builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables=True,
+ mode=model_fn_lib.ModeKeys.PREDICT,
+ export_tags=None):
# pylint: disable=line-too-long
"""Loads variables and adds them along with a MetaGraphDef for saving.
@@ -850,9 +855,14 @@ class Estimator(object):
True for the first call to this function, and the SavedModelBuilder will
raise an error if that is not the case.
mode: tf.estimator.ModeKeys value indicating which mode will be exported.
+ export_tags: The set of tags with which to save `MetaGraphDef`. If None,
+ a default set will be selected to matched the passed mode.
"""
# pylint: enable=line-too-long
+ if export_tags is None:
+ export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
input_receiver_fn = input_receiver_fn_map[mode]
+
with ops.Graph().as_default() as g:
self._create_and_assert_global_step(g)
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -877,8 +887,6 @@ class Estimator(object):
with tf_session.Session(config=self._session_config) as session:
- export_tags = model_fn_lib.EXPORT_TAG_MAP[mode]
-
local_init_op = (
estimator_spec.scaffold.local_init_op or
monitored_session.Scaffold.default_local_init_op())
@@ -1746,10 +1754,19 @@ class WarmStartSettings(
ckpt_to_initialize_from: [Required] A string specifying the directory with
checkpoint file(s) or path to checkpoint from which to warm-start the
model parameters.
- vars_to_warm_start: [Optional] A regular expression that captures which
- variables to warm-start (see tf.get_collection). Defaults to `'.*'`,
- which warm-starts all variables. If `None` is explicitly given, only
- variables specified in `var_name_to_vocab_info` will be warm-started.
+ vars_to_warm_start: [Optional] One of the following:
+
+ - A regular expression (string) that captures which variables to
+ warm-start (see tf.get_collection). This expression will only consider
+ variables in the TRAINABLE_VARIABLES collection.
+ - A list of Variables to warm-start.
+ - A list of strings, each representing a full variable name to warm-start.
+ - `None`, in which case only variables specified in
+ `var_name_to_vocab_info` will be warm-started.
+
+ Defaults to `'.*'`, which warm-starts all variables in the
+ TRAINABLE_VARIABLES collection. Note that this excludes variables such as
+ accumulators and moving statistics from batch norm.
var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
VocabInfo. The variable names should be "full" variables, not the names
of the partitions. If not explicitly provided, the variable is assumed to
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 1b70189948..a9f20f7fa4 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -814,6 +814,7 @@ class EstimatorTrainTest(test.TestCase):
def test_saving_listeners_are_used(self):
listener = test.mock.Mock(spec=training.CheckpointSaverListener)
+ listener.after_save.return_value = None
est = estimator.Estimator(
model_fn=model_fn_global_step_incrementer,
config=run_config.RunConfig(save_checkpoints_steps=10))
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index b775b19127..a7212bb83e 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -287,11 +287,11 @@ class BestExporter(Exporter):
is_the_final_export):
export_result = None
- if self._model_dir != estimator.model_dir() and self._event_file_pattern:
+ if self._model_dir != estimator.model_dir and self._event_file_pattern:
# Loads best metric from event files.
tf_logging.info('Loading best metric from event files.')
- self._model_dir = estimator.model_dir()
+ self._model_dir = estimator.model_dir
full_event_file_pattern = os.path.join(self._model_dir,
self._event_file_pattern)
self._best_eval_result = self._get_best_eval_result(
diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py
index 053c549071..4cb4bffc8d 100644
--- a/tensorflow/python/estimator/exporter_test.py
+++ b/tensorflow/python/estimator/exporter_test.py
@@ -62,7 +62,7 @@ class BestExporterTest(test.TestCase):
exports_to_keep=5)
estimator = test.mock.Mock(spec=estimator_lib.Estimator)
estimator.export_savedmodel.return_value = "export_result_path"
- estimator.model_dir.return_value = export_dir_base
+ estimator.model_dir = export_dir_base
export_result = exporter.export(estimator, export_dir_base,
"checkpoint_path", {}, False)
@@ -94,7 +94,7 @@ class BestExporterTest(test.TestCase):
exports_to_keep=1)
estimator = test.mock.Mock(spec=estimator_lib.Estimator)
estimator.export_savedmodel.return_value = "export_result_path"
- estimator.model_dir.return_value = export_dir_base
+ estimator.model_dir = export_dir_base
export_result = exporter.export(estimator, export_dir_base,
"checkpoint_path", {"loss": 0.5}, False)
@@ -133,7 +133,7 @@ class BestExporterTest(test.TestCase):
exports_to_keep=1)
estimator = test.mock.Mock(spec=estimator_lib.Estimator)
- estimator.model_dir.return_value = export_dir_base
+ estimator.model_dir = export_dir_base
estimator.export_savedmodel.return_value = "export_result_path"
export_result = exporter.export(estimator, export_dir_base,
@@ -172,7 +172,7 @@ class BestExporterTest(test.TestCase):
serving_input_receiver_fn=_serving_input_receiver_fn,
exports_to_keep=2)
estimator = test.mock.Mock(spec=estimator_lib.Estimator)
- estimator.model_dir.return_value = export_dir_base
+ estimator.model_dir = export_dir_base
# Garbage collect all but the most recent 2 exports,
# where recency is determined based on the timestamp directory names.
exporter.export(estimator, export_dir_base, None, None, False)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index dd92a0403b..2f439f765e 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -20,7 +20,7 @@ from __future__ import division
from __future__ import print_function
import os
-
+import re
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import export as export_lib
@@ -42,10 +42,12 @@ from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import tf_export
+
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -136,8 +138,9 @@ def _in_place_subclassed_model_reset(model):
To "instantiate" an identical model in a new TF graph, we reuse the original
model object, but we clear its state.
- After calling this function on a model intance, you can use the model instance
- as if it were a model clone (in particular you can use it in a new graph).
+ After calling this function on a model instance, you can use the model
+ instance as if it were a model clone (in particular you can use it in a new
+ graph).
This method clears the state of the input model. It is thus destructive.
However the original state can be restored fully by calling
@@ -220,7 +223,6 @@ def _in_place_subclassed_model_reset(model):
for name in attributes_to_cache:
attributes_cache[name] = getattr(model, name)
model._original_attributes_cache = attributes_cache
-
# Reset built state
model.built = False
model.inputs = None
@@ -340,8 +342,19 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
"""model_fn for keras Estimator."""
model = _clone_and_build_model(mode, keras_model, custom_objects, features,
labels)
+ model_output_names = []
+ # We need to make sure that the output names of the last layer in the model
+ # is the same for each of the cloned models. This is required for mirrored
+ # strategy when we call regroup.
+ if distribute_lib.has_distribution_strategy():
+ for name in model.output_names:
+ name = re.compile(r'_\d$').sub('', name)
+ model_output_names.append(name)
+ else:
+ model_output_names = model.output_names
+
# Get inputs to EstimatorSpec
- predictions = dict(zip(model.output_names, model.outputs))
+ predictions = dict(zip(model_output_names, model.outputs))
loss = None
train_op = None
@@ -445,10 +458,14 @@ def model_to_estimator(keras_model=None,
@{$programmers_guide/estimators$creating_estimators_from_keras_models}.
Args:
- keras_model: Keras model in memory.
- keras_model_path: Directory to a keras model on disk.
+ keras_model: A compiled Keras model object. This argument is mutually
+ exclusive with `keras_model_path`.
+ keras_model_path: Path to a compiled Keras model saved on disk, in HDF5
+ format, which can be generated with the `save()` method of a Keras model.
+ This argument is mutually exclusive with `keras_model`.
custom_objects: Dictionary for custom objects.
- model_dir: Directory to save Estimator model parameters, graph and etc.
+ model_dir: Directory to save Estimator model parameters, graph, summary
+ files for TensorBoard, etc.
config: Configuration object.
Returns:
@@ -460,7 +477,7 @@ def model_to_estimator(keras_model=None,
ValueError: if the keras_model_path is a GCS URI.
ValueError: if keras_model has not been compiled.
"""
- if (not keras_model) and (not keras_model_path):
+ if not (keras_model or keras_model_path):
raise ValueError(
'Either `keras_model` or `keras_model_path` needs to be provided.')
if keras_model and keras_model_path:
@@ -482,8 +499,9 @@ def model_to_estimator(keras_model=None,
if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer:
raise ValueError(
- 'The given keras model has not been compiled yet. Please compile first '
- 'before calling `model_to_estimator`.')
+ 'The given keras model has not been compiled yet. '
+ 'Please compile the model with `model.compile()` '
+ 'before calling `model_to_estimator()`.')
if isinstance(config, dict):
config = run_config_lib.RunConfig(**config)
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index a187b12b24..522662cd32 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -129,7 +129,7 @@ class TrainSpec(
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$get_started/premade_estimators#create_input_functions} for more
+ See @{$premade_estimators#create_input_functions} for more
information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
@@ -193,7 +193,7 @@ class EvalSpec(
Args:
input_fn: A function that constructs the input data for evaluation.
- See @{$get_started/premade_estimators#create_input_functions} for more
+ See @{$premade_estimators#create_input_functions} for more
information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
@@ -444,7 +444,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
'For distributed training, there can only be one `evaluator` task '
'(with task id 0). Given task id {}'.format(config.task_id))
- executor.run()
+ return executor.run()
class _StopAtSecsHook(session_run_hook.SessionRunHook):
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 7f9ef53457..c3f70df7d8 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -120,11 +120,7 @@ class DType(object):
@property
def is_numpy_compatible(self):
- numpy_incompatible = [
- types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
- types_pb2.DT_RESOURCE_REF
- ]
- return self._type_enum not in numpy_incompatible
+ return self._type_enum not in _NUMPY_INCOMPATIBLE
@property
def as_numpy_dtype(self):
@@ -162,7 +158,7 @@ class DType(object):
@property
def is_quantized(self):
"""Returns whether this is a quantized data type."""
- return self.base_dtype in [qint8, quint8, qint16, quint16, qint32]
+ return self.base_dtype in _QUANTIZED_DTYPES_NO_REF
@property
def is_unsigned(self):
@@ -401,6 +397,11 @@ quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
+_NUMPY_INCOMPATIBLE = frozenset([
+ types_pb2.DT_VARIANT, types_pb2.DT_VARIANT_REF, types_pb2.DT_RESOURCE,
+ types_pb2.DT_RESOURCE_REF
+])
+
# Maintain an intern table so that we don't have to create a large
# number of small objects.
_INTERN_TABLE = {
@@ -645,10 +646,10 @@ _TF_TO_NP = {
_np_bfloat16,
}
-QUANTIZED_DTYPES = frozenset([
- qint8, quint8, qint16, quint16, qint32, qint8_ref, quint8_ref, qint16_ref,
- quint16_ref, qint32_ref
-])
+_QUANTIZED_DTYPES_NO_REF = frozenset([qint8, quint8, qint16, quint16, qint32])
+_QUANTIZED_DTYPES_REF = frozenset(
+ [qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref])
+QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF.union(_QUANTIZED_DTYPES_NO_REF)
tf_export("QUANTIZED_DTYPES").export_constant(__name__, "QUANTIZED_DTYPES")
_PYTHON_TO_TF = {
@@ -662,10 +663,9 @@ def as_dtype(type_value):
"""Converts the given `type_value` to a `DType`.
Args:
- type_value: A value that can be converted to a `tf.DType`
- object. This may currently be a `tf.DType` object, a
- [`DataType`
- enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
+ type_value: A value that can be converted to a `tf.DType` object. This may
+ currently be a `tf.DType` object, a [`DataType`
+ enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
a string type name, or a `numpy.dtype`.
Returns:
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 7fb3a22f5a..6b031fe99b 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import collections
import copy
-import functools
import linecache
import os
import re
@@ -3862,6 +3861,9 @@ class Graph(object):
assert c.graph is g
```
+ If eager execution is enabled ops created under this context manager will be
+ added to the graph instead of executed eagerly.
+
Returns:
A context manager for using this graph as the default graph.
"""
@@ -5278,35 +5280,15 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access
@tf_contextlib.contextmanager
def get_controller(self, default):
try:
- if context.executing_eagerly():
- # A Graph alone on the context stack would keep init_scope-wrapped
- # operations graph building when entered (assuming init_scope is called
- # in a graph building context). Instead, we push a context which first
- # enables eager execution and then re-enters the Graph.
- context.context().context_switches.push(
- default.building_function,
- functools.partial(
- _enter_context_and_graph,
- context.eager_mode,
- default.as_default))
- else:
- # This Graph is being used from a graph building context. A lack of
- # context switch implies that the context is graph building.
- context.context().context_switches.push(default.building_function,
- default.as_default)
- with super(_DefaultGraphStack, self).get_controller(default) as g:
+ context.context().context_switches.push(
+ default.building_function, default.as_default)
+ with super(_DefaultGraphStack, self).get_controller(
+ default) as g, context.graph_mode():
yield g
finally:
context.context().context_switches.pop()
-@tf_contextlib.contextmanager
-def _enter_context_and_graph(context_fn, graph_fn):
- """Combines two context managers."""
- with context_fn(), graph_fn():
- yield
-
-
_default_graph_stack = _DefaultGraphStack()
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index f8501b27ae..b3bc800fee 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2224,12 +2224,25 @@ class InitScopeTest(test_util.TensorFlowTestCase):
self.assertEqual(ops.get_name_scope(), "inner")
self.assertEqual(ops.get_name_scope(), "")
- def testEagerGraphContextsExecuteEagerly(self):
+ def testEnteringGraphFromEagerIsSticky(self):
with context.eager_mode():
+ g = ops.Graph()
+ with g.as_default():
+ with ops.init_scope():
+ self.assertFalse(context.executing_eagerly())
+ self.assertEqual(g, ops.get_default_graph())
+
+ def testMixGraphEager(self):
+ with context.eager_mode():
+ c = constant_op.constant(1.0)
with ops.Graph().as_default():
- with context.graph_mode():
- with ops.init_scope():
- self.assertTrue(context.executing_eagerly())
+ with self.assertRaisesRegexp(
+ RuntimeError, "Attempting to capture an EagerTensor"):
+ math_ops.add(c, c)
+ c2 = constant_op.constant(2.0)
+ with self.assertRaisesRegexp(
+ TypeError, "contains objects other than 'EagerTensor'"):
+ math_ops.add(c2, c2)
def testPreservesNameScopeInEagerExecution(self):
with context.eager_mode():
@@ -2263,6 +2276,11 @@ class GraphTest(test_util.TensorFlowTestCase):
with g0.as_default():
ops.reset_default_graph()
+ def testGraphContextManagerCancelsEager(self):
+ with context.eager_mode():
+ with ops.Graph().as_default():
+ self.assertFalse(context.executing_eagerly())
+
def testGraphContextManager(self):
g0 = ops.Graph()
with g0.as_default() as g1:
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 0dd29460ed..c9be3d5005 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -961,9 +961,12 @@ def unknown_shape(ndims=None):
return TensorShape([Dimension(None)] * ndims)
+_SCALAR_SHAPE = TensorShape([])
+
+
def scalar():
"""Returns a shape representing a scalar."""
- return TensorShape([])
+ return _SCALAR_SHAPE
def vector(length):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 5b01df48fe..b56483f373 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -556,12 +556,16 @@ def assert_no_new_tensors(f):
tensors_before = set(
id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
- outside_graph_key = ops.get_default_graph()._graph_key
- with ops.Graph().as_default():
+ if context.executing_eagerly():
+ f(self, **kwargs)
+ ops.reset_default_graph()
+ else:
# Run the test in a new graph so that collections get cleared when it's
# done, but inherit the graph key so optimizers behave.
- ops.get_default_graph()._graph_key = outside_graph_key
- f(self, **kwargs)
+ outside_graph_key = ops.get_default_graph()._graph_key
+ with ops.Graph().as_default():
+ ops.get_default_graph()._graph_key = outside_graph_key
+ f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.
backprop._zeros_cache.flush()
@@ -727,12 +731,12 @@ def run_in_graph_and_eager_modes(__unused__=None,
f(self, **kwargs)
if assert_no_eager_garbage:
+ ops.reset_default_graph()
run_eagerly = assert_no_new_tensors(
assert_no_garbage_created(run_eagerly))
with context.eager_mode():
- with ops.Graph().as_default():
- run_eagerly(self, **kwargs)
+ run_eagerly(self, **kwargs)
return decorated
@@ -1027,7 +1031,9 @@ class TensorFlowTestCase(googletest.TestCase):
rewriter_config_pb2.RewriterConfig.OFF)
return config
- if graph is None:
+ if context.executing_eagerly():
+ yield None
+ elif graph is None:
if self._cached_session is None:
self._cached_session = session.Session(
graph=None, config=prepare_config(config))
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 0f53762f6f..0178908bcc 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -619,6 +619,7 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
ReferenceCycleTest().test_has_no_cycle()
+ @test_util.run_in_graph_and_eager_modes()
def test_no_leaked_tensor_decorator(self):
class LeakedTensorTest(object):
@@ -628,11 +629,11 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
@test_util.assert_no_new_tensors
def test_has_leak(self):
- self.a = constant_op.constant([3.])
+ self.a = constant_op.constant([3.], name="leak")
@test_util.assert_no_new_tensors
def test_has_no_leak(self):
- constant_op.constant([3.])
+ constant_op.constant([3.], name="no-leak")
with self.assertRaisesRegexp(AssertionError, "Tensors not deallocated"):
LeakedTensorTest().test_has_leak()
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 5d730695b9..fe40c9fbed 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -135,6 +135,7 @@ py_library(
deps = [
":backend",
"//tensorflow/python/data",
+ "//tensorflow/python/training/checkpointable:data_structures_base",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 552e7b7d73..a6b5940e2f 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -41,6 +41,7 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures_base
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -321,7 +322,10 @@ class Network(base_layer.Layer):
no_dependency = isinstance(value, checkpointable.NoDependency)
if no_dependency:
value = value.value
- if isinstance(value, (base_layer.Layer, Network)):
+ if isinstance(value, (
+ base_layer.Layer,
+ Network,
+ data_structures_base.CheckpointableDataStructureBase)):
try:
is_graph_network = self._is_graph_network
except AttributeError:
@@ -1424,7 +1428,15 @@ class Network(base_layer.Layer):
It will be called on each line of the summary.
You can set it to a custom function
in order to capture the string summary.
+
+ Raises:
+ ValueError: if `summary()` is called before the model is built.
"""
+ if not self.built:
+ raise ValueError('This model has never been called, thus its weights '
+ 'have not yet been created, so no summary can be '
+ 'displayed. Build the model first '
+ '(e.g. by calling it on some data).')
print_layer_summary(self,
line_length=line_length,
positions=positions,
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index ff50d0b6e2..6d625f16c2 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -112,6 +112,8 @@ class Model(Network):
super(Model, self).__init__(*args, **kwargs)
# Create a cache for iterator get_next op.
self._iterator_get_next = weakref.WeakKeyDictionary()
+ # Create a cache for dataset - uninitialized iterators
+ self._dataset_iterator_cache = weakref.WeakKeyDictionary()
def compile(self,
optimizer,
@@ -670,12 +672,12 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset iterator.
+ - A `tf.data` dataset or a dataset iterator.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset iterator,
- `y` should not be specified
+ tensor targets, or inversely). If `x` is a dataset or a
+ dataset iterator, `y` should not be specified
(since targets will be obtained from the iterator).
sample_weight: An optional sample-weight array passed by the user to
weight the importance of each sample in `x`.
@@ -706,11 +708,16 @@ class Model(Network):
RuntimeError: If the model was never compiled.
"""
if isinstance(x, dataset_ops.Dataset):
- raise ValueError('You passed a `Dataset` instance to your model (%s), '
- 'which is not supported. Instead, pass an `Iterator`, '
- 'which you can obtain e.g. via '
- '`dataset.make_one_shot_iterator()` (the exact method '
- 'to use will depend on your specific dataset).' % x)
+ if context.executing_eagerly():
+ x = x.make_one_shot_iterator()
+ else:
+ if x in self._dataset_iterator_cache:
+ x = self._dataset_iterator_cache[x]
+ else:
+ iterator = x.make_initializable_iterator()
+ self._dataset_iterator_cache[x] = iterator
+ x = iterator
+ K.get_session().run(x.initializer)
# Validates `steps` argument based on x's type.
if check_steps:
@@ -719,7 +726,7 @@ class Model(Network):
is_x_eager_iterator = isinstance(x, iterator_ops.EagerIterator)
is_x_iterator = isinstance(x, iterator_ops.Iterator)
- # Validate user inputs when data is given as a dataset iterator.
+ # Validate user inputs when data is given as a dataset or dataset iterator.
if is_x_iterator or is_x_eager_iterator:
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
@@ -839,7 +846,8 @@ class Model(Network):
# in the case where all inputs are value arrays.
if context.executing_eagerly():
- # In eager mode, do not do shape validation.
+ # In eager mode, do not do shape validation
+ # since the network has no input nodes (placeholders) to be fed.
feed_input_names = self.input_names
feed_input_shapes = None
elif not self._is_graph_network:
@@ -1130,19 +1138,19 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset iterator.
+ - A `tf.data` dataset or a dataset iterator.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset iterator,
- `y` should not be specified
+ tensor targets, or inversely). If `x` is a dataset or dataset
+ iterator, `y` should not be specified
(since targets will be obtained from the iterator).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
- form of symbolic tensors or dataset iterators (since they generate
- batches).
+ form of symbolic tensors, datasets, or dataset iterators
+ (since they generate batches).
epochs: Integer. Number of epochs to train the model.
An epoch is an iteration over the entire `x` and `y`
data provided.
@@ -1164,7 +1172,7 @@ class Model(Network):
on this data at the end of each epoch.
The validation data is selected from the last samples
in the `x` and `y` data provided, before shuffling. This argument is
- not supported when `x` is a dataset iterator.
+ not supported when `x` is a dataset or a dataset iterator.
validation_data: Data on which to evaluate
the loss and any model metrics at the end of each epoch.
The model will not be trained on this data.
@@ -1172,7 +1180,7 @@ class Model(Network):
`validation_data` could be:
- tuple `(x_val, y_val)` of Numpy arrays or tensors
- tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
- - dataset iterator
+ - dataset or a dataset iterator
shuffle: Boolean (whether to shuffle the training data
before each epoch) or str (for 'batch').
'batch' is a special option for dealing with the
@@ -1195,7 +1203,7 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
@@ -1252,7 +1260,8 @@ class Model(Network):
# Prepare validation data.
if validation_data:
if (isinstance(validation_data, iterator_ops.Iterator) or
- isinstance(validation_data, iterator_ops.EagerIterator)):
+ isinstance(validation_data, iterator_ops.EagerIterator) or
+ isinstance(validation_data, dataset_ops.Dataset)):
val_x = validation_data
val_y = None
val_sample_weight = None
@@ -1266,8 +1275,9 @@ class Model(Network):
'When passing a `validation_data` argument, '
'it must contain either 2 items (x_val, y_val), '
'or 3 items (x_val, y_val, val_sample_weights), '
- 'or alternatively it could be a dataset iterator. However we '
- 'received `validation_data=%s`' % validation_data)
+ 'or alternatively it could be a dataset or a '
+ 'dataset or a dataset iterator. '
+ 'However we received `validation_data=%s`' % validation_data)
# Validate and standardize validation data.
val_x, val_y, val_sample_weights = self._standardize_user_data(
@@ -1351,19 +1361,19 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset iterator.
+ - A `tf.data` dataset or a dataset iterator.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset iterator,
- `y` should not be specified
- (since targets will be obtained from the iterator).
+ tensor targets, or inversely).
+ If `x` is a dataset or a dataset iterator, `y` should not be specified
+ (since targets will be obtained from the iterator/dataset).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` is your data is in the
- form of symbolic tensors or dataset iterators (since they generate
- batches).
+ form of symbolic tensors, datasets, or dataset iterators
+ (since they generate batches).
verbose: 0 or 1. Verbosity mode.
0 = silent, 1 = progress bar.
sample_weight: Optional Numpy array of weights for
@@ -1377,7 +1387,7 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator.
steps: Integer or `None`.
Total number of steps (batches of samples)
before declaring the evaluation round finished.
@@ -1426,13 +1436,13 @@ class Model(Network):
(in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- - A `tf.data` dataset iterator.
+ - A `tf.data` dataset or a dataset iterator.
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` is your data is in the
- form of symbolic tensors or dataset iterators (since they generate
- batches).
+ form of symbolic tensors, dataset, or dataset iterators
+ (since they generate batches).
verbose: Verbosity mode, 0 or 1.
steps: Total number of steps (batches of samples)
before declaring the prediction round finished.
@@ -1473,12 +1483,12 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset iterator.
+ - A `tf.data` dataset or a dataset iterator.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset iterator,
- `y` should not be specified
+ tensor targets, or inversely). If `x` is a dataset or a
+ dataset iterator, `y` should not be specified
(since targets will be obtained from the iterator).
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
@@ -1487,8 +1497,7 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
sample_weight_mode="temporal" in compile(). This argument is not
- supported when `x` is a dataset iterator.
-
+ supported when `x` is a dataset or a dataset iterator.
class_weight: Optional dictionary mapping
class indices (integers) to
a weight (float) to apply to the model's loss for the samples
@@ -1537,12 +1546,12 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset iterator.
+ - A `tf.data` dataset or a dataset iterator.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset iterator,
- `y` should not be specified
+ tensor targets, or inversely). If `x` is a dataset or a
+ dataset iterator, `y` should not be specified
(since targets will be obtained from the iterator).
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
@@ -1551,7 +1560,7 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
sample_weight_mode="temporal" in compile(). This argument is not
- supported when `x` is a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1590,7 +1599,7 @@ class Model(Network):
(in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- - A `tf.data` dataset iterator.
+ - A `tf.data` dataset or a dataset iterator.
Returns:
Numpy array(s) of predictions.
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 7dec0bbf8a..5c02d36382 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1742,7 +1742,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
# Test with validation split
with self.assertRaisesRegexp(
ValueError, '`validation_split` argument is not supported '
- 'when input `x` is a dataset iterator'):
+ 'when input `x` is a dataset or a dataset iterator'):
model.fit(iterator,
epochs=1, steps_per_epoch=2, verbose=0,
validation_split=0.5, validation_steps=2)
@@ -1751,7 +1751,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
sample_weight = np.random.random((10,))
with self.assertRaisesRegexp(
ValueError, '`sample_weight` argument is not supported '
- 'when input `x` is a dataset iterator'):
+ 'when input `x` is a dataset or a dataset iterator'):
model.fit(
iterator,
epochs=1,
@@ -1761,10 +1761,6 @@ class TestTrainingWithDatasetIterators(test.TestCase):
# Test invalid usage
with self.assertRaisesRegexp(ValueError,
- 'Instead, pass an `Iterator`'):
- model.fit(dataset,
- epochs=1, steps_per_epoch=2, verbose=0)
- with self.assertRaisesRegexp(ValueError,
'you should not specify a target'):
model.fit(iterator, iterator,
epochs=1, steps_per_epoch=2, verbose=0)
@@ -1829,5 +1825,129 @@ class TestTrainingWithDatasetIterators(test.TestCase):
'dataset iterator ran out of data')
+class TestTrainingWithDataset(test.TestCase):
+
+ def test_calling_model_on_same_dataset(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ # Call fit with validation data
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+ # Finalize the graph to make sure new ops aren't added when calling on the
+ # same dataset
+ ops.get_default_graph().finalize()
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_training_and_eval_methods_on_dataset(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+ model.train_on_batch(dataset)
+ model.predict_on_batch(dataset)
+
+ # Test with validation data
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+
+ # Test with validation split
+ with self.assertRaisesRegexp(
+ ValueError, '`validation_split` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(dataset,
+ epochs=1, steps_per_epoch=2, verbose=0,
+ validation_split=0.5, validation_steps=2)
+
+ # Test with sample weight.
+ sample_weight = np.random.random((10,))
+ with self.assertRaisesRegexp(
+ ValueError, '`sample_weight` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(
+ dataset,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ sample_weight=sample_weight)
+
+ # Test invalid usage
+ with self.assertRaisesRegexp(ValueError,
+ 'you should not specify a target'):
+ model.fit(dataset, dataset,
+ epochs=1, steps_per_epoch=2, verbose=0)
+
+ with self.assertRaisesRegexp(
+ ValueError, 'you should specify the `steps_per_epoch` argument'):
+ model.fit(dataset, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.evaluate(dataset, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.predict(dataset, verbose=0)
+
+ def test_dataset_input_shape_validation(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ model.compile(optimizer, loss)
+
+ # User forgets to batch the dataset
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'expected input to have 2 dimensions'):
+ model.train_on_batch(dataset)
+
+ # Wrong input shape
+ inputs = np.zeros((10, 5), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'expected input to have shape'):
+ model.train_on_batch(dataset)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 7d214d61a4..b93f999444 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -166,10 +166,16 @@ def standardize_input_data(data,
# Check shapes compatibility.
if shapes:
for i in range(len(names)):
- if shapes[i] is not None and not tensor_util.is_tensor(data[i]):
- data_shape = data[i].shape
+ if shapes[i] is not None:
+ if tensor_util.is_tensor(data[i]):
+ tensorshape = data[i].get_shape()
+ if not tensorshape:
+ continue
+ data_shape = tuple(tensorshape.as_list())
+ else:
+ data_shape = data[i].shape
shape = shapes[i]
- if data[i].ndim != len(shape):
+ if len(data_shape) != len(shape):
raise ValueError('Error when checking ' + exception_prefix +
': expected ' + names[i] + ' to have ' +
str(len(shape)) + ' dimensions, but got array '
@@ -178,7 +184,7 @@ def standardize_input_data(data,
data_shape = data_shape[1:]
shape = shape[1:]
for dim, ref_dim in zip(data_shape, shape):
- if ref_dim != dim and ref_dim:
+ if ref_dim != dim and ref_dim is not None and dim is not None:
raise ValueError(
'Error when checking ' + exception_prefix + ': expected ' +
names[i] + ' to have shape ' + str(shape) +
@@ -632,19 +638,20 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
provided by user.
"""
if y is not None:
- raise ValueError('You passed a dataset iterator (%s) as input `x` to '
- 'your model. In that case, you should not specify '
- 'a target (`y`) argument, since the dataset iterator '
- 'generates both input data and target data. '
+ raise ValueError('You passed a dataset or dataset iterator (%s) as '
+ 'input `x` to your model. In that case, you should '
+ 'not specify a target (`y`) argument, since the dataset '
+ 'or dataset iterator generates both input data and '
+ 'target data. '
'Received: %s' % (x, y))
if sample_weight is not None:
- raise ValueError('`sample_weight` argument is not supported when input'
- ' `x` is a dataset iterator. '
+ raise ValueError('`sample_weight` argument is not supported when input '
+ '`x` is a dataset or a dataset iterator. '
'Received: x=%s, sample_weight=%s' % (x, sample_weight))
if validation_split is not None and validation_split != 0.0:
raise ValueError(
'`validation_split` argument is not supported when '
- 'input `x` is a dataset iterator. '
+ 'input `x` is a dataset or a dataset iterator. '
'Received: x=%s, validation_split=%f' % (x, validation_split))
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py
index 5c4a2dbe92..ad6594279d 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
+from tensorflow.python.framework import constant_op
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
@@ -71,10 +72,11 @@ class _CuDNNRNN(RNN):
self.constants_spec = None
self._states = None
self._num_constants = None
+ self._vector_shape = constant_op.constant([-1])
def _canonical_to_params(self, weights, biases):
- weights = [array_ops.reshape(x, (-1,)) for x in weights]
- biases = [array_ops.reshape(x, (-1,)) for x in biases]
+ weights = [array_ops.reshape(x, self._vector_shape) for x in weights]
+ biases = [array_ops.reshape(x, self._vector_shape) for x in biases]
return array_ops.concat(weights + biases, axis=0)
def call(self, inputs, mask=None, training=None, initial_state=None):
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 83b353600a..3dfad9c130 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2334,6 +2334,9 @@ cuda_py_test(
"//tensorflow/python:nn_ops",
],
shard_count = 2,
+ tags = [
+ "no_gpu", # Flaky: b/80127739
+ ],
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/accumulate_n_eager_test.py b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py
index dc11b7dece..5f516f2c7e 100644
--- a/tensorflow/python/kernel_tests/accumulate_n_eager_test.py
+++ b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py
@@ -43,10 +43,9 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase):
np.random.seed(12345)
x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)]
tf_x = ops.convert_n_to_tensor(x)
- with self.test_session(use_gpu=True):
- self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).numpy())
- self.assertAllClose(x[0] * 5,
- math_ops.accumulate_n([tf_x[0]] * 5).numpy())
+ self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x))
+ self.assertAllClose(x[0] * 5,
+ math_ops.accumulate_n([tf_x[0]] * 5))
def testGrad(self):
np.random.seed(42)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD
index 30e6289420..4f92ab0795 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/BUILD
+++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD
@@ -52,7 +52,7 @@ tf_py_test(
tf_py_test(
name = "stats_ops_test",
- size = "small",
+ size = "medium",
srcs = ["stats_ops_test.py"],
additional_deps = [
"//tensorflow/python:boosted_trees_ops",
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index 5cceb98cff..568e695fd5 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import boosted_trees_ops
from tensorflow.python.platform import googletest
@@ -388,6 +391,41 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
],
result.eval())
+ def _verify_precision(self, length):
+ with self.test_session():
+ max_splits = 1
+ num_buckets = 1
+ node_ids = array_ops.fill([length], 0)
+
+ gradients = constant_op.constant(
+ 2.0 / length, dtype=dtypes.float32, shape=[length, 1])
+ hessians = constant_op.constant(
+ 0.2 / length, dtype=dtypes.float32, shape=[length, 1])
+
+ bucketized_features = array_ops.zeros([length], dtype=dtypes.int32)
+
+ result = boosted_trees_ops.make_stats_summary(
+ node_ids, gradients, hessians, [bucketized_features], max_splits,
+ num_buckets) # shape=[max_splits, num_buckets, num_features, 2]
+
+ self.assertAllClose([[[[2., 0.2]]]], result.eval())
+
+ def testMakeStatsSummaryNumericalPrecisionSmallBatch(self):
+ """Tests numeric precision."""
+ self._verify_precision(length=2000)
+
+ def testMakeStatsSummaryNumericalPrecisionMediumBatch(self):
+ """Tests numeric precision."""
+ self._verify_precision(length=100000)
+
+ def testMakeStatsSummaryNumericalPrecisionLargeBatch(self):
+ """Tests numeric precision."""
+ self._verify_precision(length=1000000)
+
+ def testMakeStatsSummaryNumericalPrecisionMegaBatch(self):
+ """Tests numeric precision."""
+ self._verify_precision(length=50000000)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py
index a7fe336e6a..8b11556330 100644
--- a/tensorflow/python/kernel_tests/distributions/bijector_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py
@@ -90,9 +90,10 @@ class IntentionallyMissingError(Exception):
class BrokenBijector(bijector.Bijector):
"""Forward and inverse are not inverses of each other."""
- def __init__(self, forward_missing=False, inverse_missing=False):
+ def __init__(
+ self, forward_missing=False, inverse_missing=False, validate_args=False):
super(BrokenBijector, self).__init__(
- validate_args=False, forward_min_event_ndims=0, name="broken")
+ validate_args=validate_args, forward_min_event_ndims=0, name="broken")
self._forward_missing = forward_missing
self._inverse_missing = inverse_missing
@@ -116,6 +117,33 @@ class BrokenBijector(bijector.Bijector):
raise IntentionallyMissingError
return math_ops.log(2.)
+class BijectorTestEventNdims(test.TestCase):
+
+ def testBijectorNonIntegerEventNdims(self):
+ bij = BrokenBijector()
+ with self.assertRaisesRegexp(ValueError, "Expected integer"):
+ bij.forward_log_det_jacobian(1., event_ndims=1.5)
+ with self.assertRaisesRegexp(ValueError, "Expected integer"):
+ bij.inverse_log_det_jacobian(1., event_ndims=1.5)
+
+ def testBijectorArrayEventNdims(self):
+ bij = BrokenBijector()
+ with self.assertRaisesRegexp(ValueError, "Expected scalar"):
+ bij.forward_log_det_jacobian(1., event_ndims=(1, 2))
+ with self.assertRaisesRegexp(ValueError, "Expected scalar"):
+ bij.inverse_log_det_jacobian(1., event_ndims=(1, 2))
+
+ def testBijectorDynamicEventNdims(self):
+ bij = BrokenBijector(validate_args=True)
+ event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
+ with self.test_session():
+ with self.assertRaisesOpError("Expected scalar"):
+ bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
+ event_ndims: (1, 2)})
+ with self.assertRaisesOpError("Expected scalar"):
+ bij.inverse_log_det_jacobian(1., event_ndims=event_ndims).eval({
+ event_ndims: (1, 2)})
+
@six.add_metaclass(abc.ABCMeta)
class BijectorCachingTestBase(object):
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index c89994591c..b59e3dd7e7 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -463,9 +463,8 @@ class PyFuncTest(test.TestCase):
a = array_ops.ones((3, 3), dtype=dtypes.int32)
x = array_ops.ones((3, 1), dtype=dtypes.int32)
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
- with self.test_session():
- ret = self.evaluate(output)
- self.assertAllEqual(ret, [[3], [3], [3]])
+ ret = self.evaluate(output)
+ self.assertAllEqual(ret, [[3], [3], [3]])
@test_util.run_in_graph_and_eager_modes()
def testEagerSingleOutputFloat32(self):
diff --git a/tensorflow/python/lib/core/py_exception_registry.cc b/tensorflow/python/lib/core/py_exception_registry.cc
index 6637de632b..d03cf8930b 100644
--- a/tensorflow/python/lib/core/py_exception_registry.cc
+++ b/tensorflow/python/lib/core/py_exception_registry.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/python/lib/core/py_exception_registry.h"
-
#include <Python.h>
+#include "tensorflow/python/lib/core/py_exception_registry.h"
+
namespace tensorflow {
PyExceptionRegistry* PyExceptionRegistry::singleton_ = nullptr;
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 8c6bb7955a..30c1a9c759 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <array>
+#include <Python.h>
+
#include "numpy/arrayobject.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -33,8 +35,6 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
-#include <Python.h>
-
namespace tensorflow {
namespace {
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 32ea737a99..386be35ba2 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -51,6 +51,10 @@ bool IsPyInt(PyObject* obj) {
#endif
}
+bool IsPyDouble(PyObject* obj) {
+ return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type.
+}
+
bool IsPyFloat(PyObject* obj) {
return PyFloat_Check(obj) ||
PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
@@ -113,8 +117,10 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
"Attempted to convert an invalid sequence to a Tensor.");
}
}
- } else if (IsPyFloat(obj)) {
+ } else if (IsPyDouble(obj)) {
*dtype = DT_DOUBLE;
+ } else if (IsPyFloat(obj)) {
+ *dtype = DT_FLOAT;
} else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
// Have to test for bool before int, since IsInt(True/False) == true.
*dtype = DT_BOOL;
@@ -433,7 +439,7 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
break;
}
switch (infer_dtype) {
- case DT_DOUBLE:
+ case DT_FLOAT:
// TODO(josh11b): Handle mixed floats and complex numbers?
if (requested_dtype == DT_INVALID) {
// TensorFlow uses float32s to represent floating point numbers
@@ -446,7 +452,8 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
// final type.
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
}
-
+ case DT_DOUBLE:
+ RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
case DT_INT64:
if (requested_dtype == DT_INVALID) {
const char* error = ConvertInt32(obj, shape, ret);
diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc
index 00cbf0c532..dcda1f4a44 100644
--- a/tensorflow/python/lib/core/py_util.cc
+++ b/tensorflow/python/lib/core/py_util.cc
@@ -15,9 +15,10 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_util.h"
+#include <Python.h>
+
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
-#include <Python.h>
namespace tensorflow {
namespace {
diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h
index 32d2868886..35d71f7629 100644
--- a/tensorflow/python/lib/core/safe_ptr.h
+++ b/tensorflow/python/lib/core/safe_ptr.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <Python.h>
+
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py
index 59f5075f17..f22fb253e4 100644
--- a/tensorflow/python/lib/io/file_io.py
+++ b/tensorflow/python/lib/io/file_io.py
@@ -21,6 +21,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import binascii
import os
import uuid
@@ -33,6 +34,10 @@ from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
+# A good default block size depends on the system in question.
+# A somewhat conservative default chosen here.
+_DEFAULT_BLOCK_SIZE = 16 * 1024 * 1024
+
class FileIO(object):
"""FileIO class that exposes methods to read / write to / from files.
@@ -551,3 +556,56 @@ def stat(filename):
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.Stat(compat.as_bytes(filename), file_statistics, status)
return file_statistics
+
+
+def filecmp(filename_a, filename_b):
+ """Compare two files, returning True if they are the same, False otherwise.
+
+ We check size first and return False quickly if the files are different sizes.
+ If they are the same size, we continue to generating a crc for the whole file.
+
+ You might wonder: why not use Python's filecmp.cmp() instead? The answer is
+ that the builtin library is not robust to the many different filesystems
+ TensorFlow runs on, and so we here perform a similar comparison with
+ the more robust FileIO.
+
+ Args:
+ filename_a: string path to the first file.
+ filename_b: string path to the second file.
+
+ Returns:
+ True if the files are the same, False otherwise.
+ """
+ size_a = FileIO(filename_a, "rb").size()
+ size_b = FileIO(filename_b, "rb").size()
+ if size_a != size_b:
+ return False
+
+ # Size is the same. Do a full check.
+ crc_a = file_crc32(filename_a)
+ crc_b = file_crc32(filename_b)
+ return crc_a == crc_b
+
+
+def file_crc32(filename, block_size=_DEFAULT_BLOCK_SIZE):
+ """Get the crc32 of the passed file.
+
+ The crc32 of a file can be used for error checking; two files with the same
+ crc32 are considered equivalent. Note that the entire file must be read
+ to produce the crc32.
+
+ Args:
+ filename: string, path to a file
+ block_size: Integer, process the files by reading blocks of `block_size`
+ bytes. Use -1 to read the file as once.
+
+ Returns:
+ hexadecimal as string, the crc32 of the passed file.
+ """
+ crc = 0
+ with FileIO(filename, mode="rb") as f:
+ chunk = f.read(n=block_size)
+ while chunk:
+ crc = binascii.crc32(chunk, crc)
+ chunk = f.read(n=block_size)
+ return hex(crc & 0xFFFFFFFF)
diff --git a/tensorflow/python/lib/io/file_io_test.py b/tensorflow/python/lib/io/file_io_test.py
index 223858edfa..c21eb93103 100644
--- a/tensorflow/python/lib/io/file_io_test.py
+++ b/tensorflow/python/lib/io/file_io_test.py
@@ -491,5 +491,96 @@ class FileIoTest(test.TestCase):
v = file_io.file_exists(file_path)
self.assertEqual(v, True)
+ def testFilecmp(self):
+ file1 = os.path.join(self._base_dir, "file1")
+ file_io.write_string_to_file(file1, "This is a sentence\n" * 100)
+
+ file2 = os.path.join(self._base_dir, "file2")
+ file_io.write_string_to_file(file2, "This is another sentence\n" * 100)
+
+ file3 = os.path.join(self._base_dir, "file3")
+ file_io.write_string_to_file(file3, u"This is another sentence\n" * 100)
+
+ self.assertFalse(file_io.filecmp(file1, file2))
+ self.assertTrue(file_io.filecmp(file2, file3))
+
+ def testFilecmpSameSize(self):
+ file1 = os.path.join(self._base_dir, "file1")
+ file_io.write_string_to_file(file1, "This is a sentence\n" * 100)
+
+ file2 = os.path.join(self._base_dir, "file2")
+ file_io.write_string_to_file(file2, "This is b sentence\n" * 100)
+
+ file3 = os.path.join(self._base_dir, "file3")
+ file_io.write_string_to_file(file3, u"This is b sentence\n" * 100)
+
+ self.assertFalse(file_io.filecmp(file1, file2))
+ self.assertTrue(file_io.filecmp(file2, file3))
+
+ def testFilecmpBinary(self):
+ file1 = os.path.join(self._base_dir, "file1")
+ file_io.FileIO(file1, "wb").write("testing\n\na")
+
+ file2 = os.path.join(self._base_dir, "file2")
+ file_io.FileIO(file2, "wb").write("testing\n\nb")
+
+ file3 = os.path.join(self._base_dir, "file3")
+ file_io.FileIO(file3, "wb").write("testing\n\nb")
+
+ file4 = os.path.join(self._base_dir, "file4")
+ file_io.FileIO(file4, "wb").write("testing\n\ntesting")
+
+ self.assertFalse(file_io.filecmp(file1, file2))
+ self.assertFalse(file_io.filecmp(file1, file4))
+ self.assertTrue(file_io.filecmp(file2, file3))
+
+ def testFileCrc32(self):
+ file1 = os.path.join(self._base_dir, "file1")
+ file_io.write_string_to_file(file1, "This is a sentence\n" * 100)
+ crc1 = file_io.file_crc32(file1)
+
+ file2 = os.path.join(self._base_dir, "file2")
+ file_io.write_string_to_file(file2, "This is another sentence\n" * 100)
+ crc2 = file_io.file_crc32(file2)
+
+ file3 = os.path.join(self._base_dir, "file3")
+ file_io.write_string_to_file(file3, "This is another sentence\n" * 100)
+ crc3 = file_io.file_crc32(file3)
+
+ self.assertTrue(crc1 != crc2)
+ self.assertEqual(crc2, crc3)
+
+ def testFileCrc32WithBytes(self):
+ file1 = os.path.join(self._base_dir, "file1")
+ file_io.write_string_to_file(file1, "This is a sentence\n" * 100)
+ crc1 = file_io.file_crc32(file1, block_size=24)
+
+ file2 = os.path.join(self._base_dir, "file2")
+ file_io.write_string_to_file(file2, "This is another sentence\n" * 100)
+ crc2 = file_io.file_crc32(file2, block_size=24)
+
+ file3 = os.path.join(self._base_dir, "file3")
+ file_io.write_string_to_file(file3, "This is another sentence\n" * 100)
+ crc3 = file_io.file_crc32(file3, block_size=-1)
+
+ self.assertTrue(crc1 != crc2)
+ self.assertEqual(crc2, crc3)
+
+ def testFileCrc32Binary(self):
+ file1 = os.path.join(self._base_dir, "file1")
+ file_io.FileIO(file1, "wb").write("testing\n\n")
+ crc1 = file_io.file_crc32(file1)
+
+ file2 = os.path.join(self._base_dir, "file2")
+ file_io.FileIO(file2, "wb").write("testing\n\n\n")
+ crc2 = file_io.file_crc32(file2)
+
+ file3 = os.path.join(self._base_dir, "file3")
+ file_io.FileIO(file3, "wb").write("testing\n\n\n")
+ crc3 = file_io.file_crc32(file3)
+
+ self.assertTrue(crc1 != crc2)
+ self.assertEqual(crc2, crc3)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py
new file mode 100644
index 0000000000..a05fd15eca
--- /dev/null
+++ b/tensorflow/python/ops/collective_ops.py
@@ -0,0 +1,133 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow collective Ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import device
+from tensorflow.python.ops import gen_collective_ops
+
+
+def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
+ subdiv_offsets=(0)):
+ """Reduces tensors collectively, across devices.
+
+ Args:
+ t: the tensor to be reduced.
+ group_size: the total number of tensors to be collectively reduced.
+ Each must reside on a different device.
+ group_key: an integer identifying the group of devices.
+ instance_key: an integer identifying the participating group of Ops.
+ merge_op: string naming the binary Op to be applied to compute each
+ partial reduction.
+ final_op: string naming the unary Op to be applied to each fully
+ reduced value. Can be 'Id' for no operation.
+ subdiv_offsets: a list of integer offsets into the tensor at which each
+ independent subdivision should begin. Use [0] if no subdivision should
+ be done.
+
+ Returns:
+ An Op implementing the distributed reduction.
+
+ Raises:
+ ValueError: if any of the input parameter constraints are not met.
+ """
+ if not device.canonical_name(t.device):
+ raise ValueError('Device assignment required for collective ops')
+ if group_size <= 1:
+ raise ValueError('Parameter group_size to add_reduce must be at least 2.')
+ return gen_collective_ops.collective_reduce(t,
+ group_size=group_size,
+ group_key=group_key,
+ instance_key=instance_key,
+ merge_op=merge_op,
+ final_op=final_op,
+ subdiv_offsets=subdiv_offsets)
+
+
+def broadcast_send(t, shape, dtype, group_size, group_key, instance_key):
+ """Broadcasts one tensor to a group of others, across devices.
+
+ Args:
+ t: the tensor to be sent.
+ shape: the shape of the tensor being sent, which must agree with t.
+ dtype: the type of the tensor being sent, which must agree with t.
+ group_size: one plus the number of receiving tensors, i.e. the total
+ number of devices participating. Each tensor must reside on a
+ different device.
+ group_key: an integer identifying the group of devices.
+ instance_key: an integer identifying the participating group of Ops.
+
+ Returns:
+ An Op implementing the distributed broadcast send.
+
+ Raises:
+ ValueError: if any of the input parameter constraints are not met.
+
+ Note that the shape and dtype arguments appear redundant since they
+ should be obtainable from t. The are two reasons for including
+ them. First, the shape and type of tensors passed via broadcast must
+ be known ahead of time in their most specific form so that the receive
+ side can allocate memory for the operation and shape/type inference can
+ carry forward from there. Including the same declarations on the
+ send side clarifies a commitment already made. Secondly, having nearly
+ identical use syntax for send and receive sides may simplify tool-driven
+ generation of broadcast.
+ """
+ if not device.canonical_name(t.device):
+ raise ValueError('Device assignment required for collective ops')
+ if group_size <= 1:
+ raise ValueError(
+ 'Parameter group_size to broadcast_send must be at least 2.')
+ if t.shape != shape:
+ raise ValueError(
+ 'Shape of broadcast_send tensor not equal to delcared shape')
+ if t.dtype != dtype:
+ raise ValueError(
+ 'Type of broadcast_send tensor not equal to declared type')
+ return gen_collective_ops.collective_bcast_send(t,
+ shape=shape,
+ group_size=group_size,
+ group_key=group_key,
+ instance_key=instance_key)
+
+
+def broadcast_recv(shape, dtype, group_size, group_key, instance_key):
+ """Receives a broadcasts tensor, across devices.
+
+ Args:
+ shape: Shape of the tensor to be received.
+ dtype: Type of the tensor to be received.
+ group_size: one plus the number of receiving tensors, i.e. the total
+ number of devices participating. Each tensor must reside on a
+ different device.
+ group_key: an integer identifying the group of devices.
+ instance_key: an integer identifying the participating group of Ops.
+
+ Returns:
+ An Op implementing the broadcast receive.
+
+ Raises:
+ ValueError: if any of the input parameter constraints are not met.
+ """
+ if group_size <= 1:
+ raise ValueError(
+ 'Parameter group_size to broadcast_send must be at least 2.')
+ return gen_collective_ops.collective_bcast_recv(shape=shape,
+ T=dtype,
+ group_size=group_size,
+ group_key=group_key,
+ instance_key=instance_key)
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
new file mode 100644
index 0000000000..8e16cffdf4
--- /dev/null
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Collective Operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import collective_ops
+from tensorflow.python.platform import test
+
+# TODO(tucker): Make these ops work in eager mode. b/79776476
+
+
+class CollectiveOpTest(test.TestCase):
+
+ def _testCollectiveReduce(self, t0, t1, expected):
+ group_key = 1
+ instance_key = 1
+ with self.test_session(
+ config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
+ with ops.device('/CPU:0'):
+ in0 = constant_op.constant(t0)
+ colred0 = collective_ops.all_reduce(in0, 2, group_key, instance_key,
+ 'Add', 'Div', [0])
+ with ops.device('/CPU:1'):
+ in1 = constant_op.constant(t1)
+ colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
+ 'Add', 'Div', [0])
+ run_options = config_pb2.RunOptions()
+ run_options.experimental.collective_graph_key = 1
+ results = sess.run([colred0, colred1], options=run_options)
+ self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
+ self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
+
+ def testCollectiveReduce(self):
+ self._testCollectiveReduce([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
+ [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
+ [0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
+
+ def _testCollectiveBroadcast(self, t0):
+ group_key = 1
+ instance_key = 1
+ with self.test_session(
+ config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
+ with ops.device('/CPU:0'):
+ in0 = constant_op.constant(t0)
+ out0 = collective_ops.broadcast_send(in0, in0.shape, in0.dtype,
+ 2, group_key, instance_key)
+ with ops.device('/CPU:1'):
+ c1 = constant_op.constant(t0)
+ out1 = collective_ops.broadcast_recv(c1.shape, c1.dtype,
+ 2, group_key, instance_key)
+ run_options = config_pb2.RunOptions()
+ run_options.experimental.collective_graph_key = 1
+ results = sess.run([out0, out1], options=run_options)
+ self.assertAllClose(results[0], t0, rtol=1e-5, atol=1e-5)
+ self.assertAllClose(results[1], t0, rtol=1e-5, atol=1e-5)
+
+ def testCollectiveBroadcast(self):
+ self._testCollectiveBroadcast([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py
index d7fb3f1f78..84d9d40a35 100644
--- a/tensorflow/python/ops/distributions/bernoulli.py
+++ b/tensorflow/python/ops/distributions/bernoulli.py
@@ -71,7 +71,7 @@ class Bernoulli(distribution.Distribution):
Raises:
ValueError: If p and logits are passed, or if neither are passed.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits=logits,
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index b697848600..f28f76b6c4 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -150,7 +150,7 @@ class Beta(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[concentration1, concentration0]) as name:
self._concentration1 = self._maybe_assert_valid_concentration(
ops.convert_to_tensor(concentration1, name="concentration1"),
@@ -321,7 +321,7 @@ class BetaWithSoftplusConcentration(Beta):
validate_args=False,
allow_nan_stats=True,
name="BetaWithSoftplusConcentration"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[concentration1,
concentration0]) as name:
super(BetaWithSoftplusConcentration, self).__init__(
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index caceadf53a..b65e64d401 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -160,13 +160,20 @@ class Bijector(object):
3. `log_det_jacobian(x)`
- "The log of the determinant of the matrix of all first-order partial
- derivatives of the inverse function."
+ "The log of the absolute value of the determinant of the matrix of all
+ first-order partial derivatives of the inverse function."
Useful for inverting a transformation to compute one probability in terms
of another. Geometrically, the Jacobian determinant is the volume of the
transformation and is used to scale the probability.
+ We take the absolute value of the determinant before log to avoid NaN
+ values. Geometrically, a negative determinant corresponds to an
+ orientation-reversing transformation. It is ok for us to discard the sign
+ of the determinant because we only integrate everywhere-nonnegative
+ functions (probability densities) and the correct orientation is always the
+ one that produces a nonnegative integrand.
+
By convention, transformations of random variables are named in terms of the
forward transformation. The forward transformation creates samples, the
inverse is useful for computing probabilities.
@@ -1021,7 +1028,7 @@ class Bijector(object):
axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
# The multiplication by ones can change the inferred static shape so we try
# to recover as much as possible.
- event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
if (event_ndims_ is not None and
y.shape.ndims is not None and
ildj.shape.ndims is not None):
@@ -1036,7 +1043,7 @@ class Bijector(object):
def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
"""Compute the reduction dimensions given event_ndims."""
- event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
if event_ndims_ is not None:
return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
@@ -1046,9 +1053,18 @@ class Bijector(object):
def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
"""Check whether event_ndims is atleast min_event_ndims."""
- event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+ event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
+ event_ndims_ = tensor_util.constant_value(event_ndims)
assertions = []
+
+ if not event_ndims.dtype.is_integer:
+ raise ValueError("Expected integer dtype, got dtype {}".format(
+ event_ndims.dtype))
+
if event_ndims_ is not None:
+ if event_ndims.shape.ndims != 0:
+ raise ValueError("Expected scalar event_ndims, got shape {}".format(
+ event_ndims.shape))
if min_event_ndims > event_ndims_:
raise ValueError("event_ndims ({}) must be larger than "
"min_event_ndims ({})".format(
@@ -1056,17 +1072,29 @@ class Bijector(object):
elif self.validate_args:
assertions += [
check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
+
+ if event_ndims.shape.is_fully_defined():
+ if event_ndims.shape.ndims != 0:
+ raise ValueError("Expected scalar shape, got ndims {}".format(
+ event_ndims.shape.ndims))
+
+ elif self.validate_args:
+ assertions += [
+ check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
return assertions
- def _maybe_get_event_ndims_statically(self, event_ndims):
+ def _maybe_get_static_event_ndims(self, event_ndims):
"""Helper which returns tries to return an integer static value."""
event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
- if isinstance(event_ndims_, np.ndarray):
- if (event_ndims_.dtype not in (np.int32, np.int64) or
- len(event_ndims_.shape)):
+ if isinstance(event_ndims_, (np.generic, np.ndarray)):
+ if event_ndims_.dtype not in (np.int32, np.int64):
+ raise ValueError("Expected integer dtype, got dtype {}".format(
+ event_ndims_.dtype))
+
+ if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
raise ValueError("Expected a scalar integer, got {}".format(
event_ndims_))
- event_ndims_ = event_ndims_.tolist()
+ event_ndims_ = int(event_ndims_)
return event_ndims_
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index bbdc8c455a..b88a0518b6 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -182,7 +182,7 @@ class Categorical(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[logits, probs]) as name:
self._logits, self._probs = distribution_util.get_logits_and_probs(
logits=logits,
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 8d0d1d860b..1ab58c1450 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -154,7 +154,7 @@ class Dirichlet(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[concentration]) as name:
self._concentration = self._maybe_assert_valid_concentration(
ops.convert_to_tensor(concentration, name="concentration"),
diff --git a/tensorflow/python/ops/distributions/dirichlet_multinomial.py b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
index 3a35e0caa0..5350c82847 100644
--- a/tensorflow/python/ops/distributions/dirichlet_multinomial.py
+++ b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
@@ -191,7 +191,7 @@ class DirichletMultinomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[total_count, concentration]) as name:
# Broadcasting works because:
# * The broadcasting convention is to prepend dimensions of size [1], and
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index a6579e3246..0db4749507 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -525,7 +525,7 @@ class Distribution(_BaseDistribution):
"""Dictionary of parameters used to instantiate this `Distribution`."""
# Remove "self", "__class__", or other special variables. These can appear
# if the subclass used:
- # `parameters = distribution_util.parent_frame_arguments()`.
+ # `parameters = dict(locals())`.
return dict((k, v) for k, v in self._parameters.items()
if not k.startswith("__") and k != "self")
diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py
index 1e08f48d52..24bc3f3d3e 100644
--- a/tensorflow/python/ops/distributions/exponential.py
+++ b/tensorflow/python/ops/distributions/exponential.py
@@ -27,7 +27,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma
-from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -91,7 +90,7 @@ class Exponential(gamma.Gamma):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
# Even though all statistics of are defined for valid inputs, this is not
# true in the parent class "Gamma." Therefore, passing
# allow_nan_stats=True
@@ -144,7 +143,7 @@ class ExponentialWithSoftplusRate(Exponential):
validate_args=False,
allow_nan_stats=True,
name="ExponentialWithSoftplusRate"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[rate]) as name:
super(ExponentialWithSoftplusRate, self).__init__(
rate=nn.softplus(rate, name="softplus_rate"),
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index 7ca690d9d2..163a27f758 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -126,7 +126,7 @@ class Gamma(distribution.Distribution):
Raises:
TypeError: if `concentration` and `rate` are different dtypes.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[concentration, rate]) as name:
with ops.control_dependencies([
check_ops.assert_positive(concentration),
@@ -261,7 +261,7 @@ class GammaWithSoftplusConcentrationRate(Gamma):
validate_args=False,
allow_nan_stats=True,
name="GammaWithSoftplusConcentrationRate"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[concentration, rate]) as name:
super(GammaWithSoftplusConcentrationRate, self).__init__(
concentration=nn.softplus(concentration,
diff --git a/tensorflow/python/ops/distributions/laplace.py b/tensorflow/python/ops/distributions/laplace.py
index ee3a6a40ff..be17cf2527 100644
--- a/tensorflow/python/ops/distributions/laplace.py
+++ b/tensorflow/python/ops/distributions/laplace.py
@@ -33,7 +33,6 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
-from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -101,7 +100,7 @@ class Laplace(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` are of different dtype.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
@@ -218,7 +217,7 @@ class LaplaceWithSoftplusScale(Laplace):
validate_args=False,
allow_nan_stats=True,
name="LaplaceWithSoftplusScale"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[loc, scale]) as name:
super(LaplaceWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index 036ba45ccc..d0943e8eee 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -182,7 +182,7 @@ class Multinomial(distribution.Distribution):
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[total_count, logits, probs]) as name:
self._total_count = ops.convert_to_tensor(total_count, name="total_count")
if validate_args:
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index 0620aae10d..d0a987ba7c 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -32,7 +32,6 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
-from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -132,7 +131,7 @@ class Normal(distribution.Distribution):
Raises:
TypeError: if `loc` and `scale` have different `dtype`.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(scale)] if
validate_args else []):
@@ -244,7 +243,7 @@ class NormalWithSoftplusScale(Normal):
validate_args=False,
allow_nan_stats=True,
name="NormalWithSoftplusScale"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[scale]) as name:
super(NormalWithSoftplusScale, self).__init__(
loc=loc,
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index 9330b930b5..20a2d16181 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -157,7 +157,7 @@ class StudentT(distribution.Distribution):
Raises:
TypeError: if loc and scale are different dtypes.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[df, loc, scale]) as name:
with ops.control_dependencies([check_ops.assert_positive(df)]
if validate_args else []):
@@ -349,7 +349,7 @@ class StudentTWithAbsDfSoftplusScale(StudentT):
validate_args=False,
allow_nan_stats=True,
name="StudentTWithAbsDfSoftplusScale"):
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[df, scale]) as name:
super(StudentTWithAbsDfSoftplusScale, self).__init__(
df=math_ops.floor(math_ops.abs(df)),
diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py
index 9392464ec1..e80bf9ee42 100644
--- a/tensorflow/python/ops/distributions/transformed_distribution.py
+++ b/tensorflow/python/ops/distributions/transformed_distribution.py
@@ -252,7 +252,7 @@ class TransformedDistribution(distribution_lib.Distribution):
name: Python `str` name prefixed to Ops created by this class. Default:
`bijector.name + distribution.name`.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
name = name or (("" if bijector is None else bijector.name) +
distribution.name)
with ops.name_scope(name, values=[event_shape, batch_shape]) as name:
@@ -416,7 +416,7 @@ class TransformedDistribution(distribution_lib.Distribution):
# For caching to work, it is imperative that the bijector is the first to
# modify the input.
x = self.bijector.inverse(y)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
if self.bijector._is_injective: # pylint: disable=protected-access
@@ -435,13 +435,15 @@ class TransformedDistribution(distribution_lib.Distribution):
log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
log_prob += math_ops.cast(ildj, log_prob.dtype)
if self._is_maybe_event_override and isinstance(event_ndims, int):
- log_prob.set_shape(array_ops.broadcast_static_shape(
- x.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
+ log_prob.set_shape(
+ array_ops.broadcast_static_shape(
+ y.get_shape().with_rank_at_least(1)[:-event_ndims],
+ self.batch_shape))
return log_prob
def _prob(self, y):
x = self.bijector.inverse(y)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
if self.bijector._is_injective: # pylint: disable=protected-access
return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims)
@@ -459,8 +461,10 @@ class TransformedDistribution(distribution_lib.Distribution):
prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype))
if self._is_maybe_event_override and isinstance(event_ndims, int):
- prob.set_shape(array_ops.broadcast_static_shape(
- y.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
+ prob.set_shape(
+ array_ops.broadcast_static_shape(
+ y.get_shape().with_rank_at_least(1)[:-event_ndims],
+ self.batch_shape))
return prob
def _log_cdf(self, y):
@@ -618,15 +622,14 @@ class TransformedDistribution(distribution_lib.Distribution):
return array_ops.transpose(
x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))
- def _maybe_get_event_ndims_statically(self):
+ def _maybe_get_static_event_ndims(self):
if self.event_shape.ndims is not None:
return self.event_shape.ndims
event_ndims = array_ops.size(self.event_shape_tensor())
+ event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
- static_event_ndims = tensor_util.constant_value(event_ndims)
-
- if static_event_ndims is not None:
- return static_event_ndims
+ if event_ndims_ is not None:
+ return event_ndims_
return event_ndims
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
index dfa10331e3..e66c4a37e7 100644
--- a/tensorflow/python/ops/distributions/uniform.py
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -29,7 +29,6 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
-from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
@@ -103,7 +102,7 @@ class Uniform(distribution.Distribution):
Raises:
InvalidArgumentError: if `low >= high` and `validate_args=False`.
"""
- parameters = distribution_util.parent_frame_arguments()
+ parameters = dict(locals())
with ops.name_scope(name, values=[low, high]) as name:
with ops.control_dependencies([
check_ops.assert_less(
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 59c89d21f9..728fda28c2 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -179,6 +179,7 @@ def maybe_get_static_value(x, dtype=None):
if x is None:
return x
try:
+ # This returns an np.ndarray.
x_ = tensor_util.constant_value(x)
except TypeError:
x_ = x
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 716b54f07c..7385cb7585 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -1006,21 +1006,32 @@ def _AggregatedGrads(grads,
logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
tensor_shape, used)
else:
- out_grad = math_ops._as_indexed_slices_list(
- [g for g in out_grad if g is not None])
- out_grad = [_HandleNestedIndexedSlices(x) for x in out_grad]
- # Form IndexedSlices out of the concatenated values and
- # indices.
- out_grads[i] = ops.IndexedSlices(
- array_ops.concat([x.values for x in out_grad], 0),
- array_ops.concat([x.indices for x in out_grad], 0),
- out_grad[0].dense_shape)
+ out_grads[i] = _AggregateIndexedSlicesGradients(out_grad)
else: # not out_grad
# out_grads[i] is [], thus its aggregation is simply None.
out_grads[i] = None
return out_grads
+def _AggregateIndexedSlicesGradients(grads):
+ """Aggregates gradients of type `IndexedSlices` by concatenation."""
+ if len(grads) < 1:
+ return None
+ elif len(grads) == 1:
+ return grads[0]
+ else:
+ grads = math_ops._as_indexed_slices_list( # pylint: disable=protected-access
+ [g for g in grads if g is not None])
+ grads = [_HandleNestedIndexedSlices(x) for x in grads] # pylint: disable=protected-access
+ # Form IndexedSlices out of the concatenated values and indices.
+ concat_grad = ops.IndexedSlices(
+ array_ops.concat([x.values for x in grads], axis=0),
+ array_ops.concat([x.indices for x in grads], axis=0),
+ grads[0].dense_shape)
+
+ return concat_grad
+
+
# TODO(vrv): Make this available when we want to make it public.
def _hessian_vector_product(ys, xs, v):
"""Multiply the Hessian of `ys` wrt `xs` by `v`.
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 70d500a108..6891501ae1 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -946,5 +946,53 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
+class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
+
+ def _assert_indexed_slices_equal(self, left, right):
+ self.assertAllEqual(
+ self.evaluate(ops.convert_to_tensor(left)),
+ self.evaluate(ops.convert_to_tensor(right)))
+
+ def testNoGradients(self):
+ self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([]))
+
+ def testOneGradient(self):
+ t = math_ops._as_indexed_slices(constant_op.constant(
+ [[1., 2.], [0, 0], [3., 4.]]))
+ result = gradients_impl._AggregateIndexedSlicesGradients([t])
+ self._assert_indexed_slices_equal(t, result)
+
+ def testMultipleGradients(self):
+ t0 = math_ops._as_indexed_slices(constant_op.constant(
+ [[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(constant_op.constant(
+ [[0., 0.], [5, 6], [7., 8.]]))
+ total = constant_op.constant(
+ [[1., 2.], [5, 6], [10., 12.]])
+ result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+ self._assert_indexed_slices_equal(total, result)
+
+ def testMultipleGradientsWithNones(self):
+ t0 = math_ops._as_indexed_slices(constant_op.constant(
+ [[1., 2.], [0, 0], [3., 4.]]))
+ t1 = math_ops._as_indexed_slices(constant_op.constant(
+ [[0., 0.], [5, 6], [7., 8.]]))
+ t3 = None
+ total = constant_op.constant(
+ [[1., 2.], [5, 6], [10., 12.]])
+ result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3])
+ self._assert_indexed_slices_equal(total, result)
+
+ def testMixedTensorAndIndexedSlices(self):
+ t0 = math_ops._as_indexed_slices(constant_op.constant(
+ [[1., 2.], [0, 0], [3., 4.]]))
+ t1 = constant_op.constant(
+ [[0., 0.], [5, 6], [7., 8.]])
+ total = constant_op.constant(
+ [[1., 2.], [5, 6], [10., 12.]])
+ result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+ self._assert_indexed_slices_equal(total, result)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 3f40e3ff75..e907fc470b 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -523,7 +523,7 @@ def transpose_image(image):
@tf_export('image.central_crop')
def central_crop(image, central_fraction):
- """Crop the central region of the image.
+ """Crop the central region of the image(s).
Remove the outer parts of an image but retain the central region of the image
along each dimension. If we specify central_fraction = 0.5, this function
@@ -536,15 +536,19 @@ def central_crop(image, central_fraction):
| | where "X" is the central 50% of the image.
--------
+ This function works on either a single image (`image` is a 3-D Tensor), or a
+ batch of images (`image` is a 4-D Tensor).
+
Args:
- image: 3-D float Tensor of shape [height, width, depth]
+ image: Either a 3-D float Tensor of shape [height, width, depth], or a 4-D
+ Tensor of shape [batch_size, height, width, depth].
central_fraction: float (0, 1], fraction of size to crop
Raises:
ValueError: if central_crop_fraction is not within (0, 1].
Returns:
- 3-D float Tensor
+ 3-D / 4-D float Tensor, as per the input.
"""
with ops.name_scope(None, 'central_crop', [image]):
image = ops.convert_to_tensor(image, name='image')
@@ -553,24 +557,75 @@ def central_crop(image, central_fraction):
if central_fraction == 1.0:
return image
- image = _Assert3DImage(image)
+ _AssertAtLeast3DImage(image)
+ rank = image.get_shape().ndims
+ if rank != 3 and rank != 4:
+ raise ValueError('`image` should either be a Tensor with rank = 3 or '
+ 'rank = 4. Had rank = {}.'.format(rank))
+
+ # Helper method to return the `idx`-th dimension of `tensor`, along with
+ # a boolean signifying if the dimension is dynamic.
+ def _get_dim(tensor, idx):
+ static_shape = tensor.get_shape()[idx].value
+ if static_shape is not None:
+ return static_shape, False
+ return array_ops.shape(tensor)[idx], True
+
+ # Get the height, width, depth (and batch size, if the image is a 4-D
+ # tensor).
+ if rank == 3:
+ img_h, dynamic_h = _get_dim(image, 0)
+ img_w, dynamic_w = _get_dim(image, 1)
+ img_d = image.get_shape()[2]
+ else:
+ img_bs = image.get_shape()[0]
+ img_h, dynamic_h = _get_dim(image, 1)
+ img_w, dynamic_w = _get_dim(image, 2)
+ img_d = image.get_shape()[3]
+
+ # Compute the bounding boxes for the crop. The type and value of the
+ # bounding boxes depend on the `image` tensor's rank and whether / not the
+ # dimensions are statically defined.
+ if dynamic_h:
+ img_hd = math_ops.to_double(img_h)
+ bbox_h_start = math_ops.to_int32((img_hd - img_hd * central_fraction) / 2)
+ else:
+ img_hd = float(img_h)
+ bbox_h_start = int((img_hd - img_hd * central_fraction) / 2)
- img_shape = array_ops.shape(image)
- depth = image.get_shape()[2]
- img_h = math_ops.to_double(img_shape[0])
- img_w = math_ops.to_double(img_shape[1])
- bbox_h_start = math_ops.to_int32((img_h - img_h * central_fraction) / 2)
- bbox_w_start = math_ops.to_int32((img_w - img_w * central_fraction) / 2)
+ if dynamic_w:
+ img_wd = math_ops.to_double(img_w)
+ bbox_w_start = math_ops.to_int32((img_wd - img_wd * central_fraction) / 2)
+ else:
+ img_wd = float(img_w)
+ bbox_w_start = int((img_wd - img_wd * central_fraction) / 2)
+
+ bbox_h_size = img_h - bbox_h_start * 2
+ bbox_w_size = img_w - bbox_w_start * 2
- bbox_h_size = img_shape[0] - bbox_h_start * 2
- bbox_w_size = img_shape[1] - bbox_w_start * 2
+ if rank == 3:
+ bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0])
+ bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1])
+ else:
+ bbox_begin = array_ops.stack([0, bbox_h_start, bbox_w_start, 0])
+ bbox_size = array_ops.stack([-1, bbox_h_size, bbox_w_size, -1])
- bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0])
- bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1])
image = array_ops.slice(image, bbox_begin, bbox_size)
- # The first two dimensions are dynamic and unknown.
- image.set_shape([None, None, depth])
+ # Reshape the `image` tensor to the desired size.
+ if rank == 3:
+ image.set_shape([
+ None if dynamic_h else bbox_h_size,
+ None if dynamic_w else bbox_w_size,
+ img_d
+ ])
+ else:
+ image.set_shape([
+ img_bs,
+ None if dynamic_h else bbox_h_size,
+ None if dynamic_w else bbox_w_size,
+ img_d
+ ])
return image
@@ -1772,7 +1827,7 @@ def non_max_suppression(boxes,
scores,
max_output_size,
iou_threshold=0.5,
- score_threshold=0.0,
+ score_threshold=float('-inf'),
name=None):
"""Greedily selects a subset of bounding boxes in descending order of score.
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index c437c12c27..72c889a2e6 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -1585,14 +1585,16 @@ class CentralCropTest(test_util.TensorFlowTestCase):
self.assertEqual(y.get_shape().as_list(), post_shape)
def testNoOp(self):
- x_shape = [13, 9, 3]
- x_np = np.ones(x_shape, dtype=np.float32)
- with self.test_session(use_gpu=True):
- x = constant_op.constant(x_np, shape=x_shape)
- y = image_ops.central_crop(x, 1.0)
- y_tf = y.eval()
- self.assertAllEqual(y_tf, x_np)
- self.assertEqual(y.op.name, x.op.name)
+ x_shapes = [[13, 9, 3], [5, 13, 9, 3]]
+ for x_shape in x_shapes:
+ x_np = np.ones(x_shape, dtype=np.float32)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.central_crop(x, 1.0)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+ self.assertEqual(y.op.name, x.op.name)
def testCropping(self):
x_shape = [4, 8, 1]
@@ -1601,6 +1603,23 @@ class CentralCropTest(test_util.TensorFlowTestCase):
[1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8]],
dtype=np.int32).reshape(x_shape)
y_np = np.array([[3, 4, 5, 6], [3, 4, 5, 6]]).reshape([2, 4, 1])
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.central_crop(x, 0.5)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+ self.assertAllEqual(y_tf.shape, y_np.shape)
+
+ x_shape = [2, 4, 8, 1]
+ x_np = np.array(
+ [[1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8],
+ [1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8],
+ [8, 7, 6, 5, 4, 3, 2, 1], [8, 7, 6, 5, 4, 3, 2, 1],
+ [8, 7, 6, 5, 4, 3, 2, 1], [8, 7, 6, 5, 4, 3, 2, 1]],
+ dtype=np.int32).reshape(x_shape)
+ y_np = np.array([[[3, 4, 5, 6], [3, 4, 5, 6]],
+ [[6, 5, 4, 3], [6, 5, 4, 3]]]).reshape([2, 2, 4, 1])
with self.test_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.central_crop(x, 0.5)
@@ -1610,52 +1629,87 @@ class CentralCropTest(test_util.TensorFlowTestCase):
def testCropping2(self):
# Test case for 10315
- x_shape = [240, 320, 3]
- x_np = np.zeros(x_shape, dtype=np.int32)
- y_np = np.zeros([80, 106, 3], dtype=np.int32)
- with self.test_session(use_gpu=True):
- x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32)
- y = image_ops.central_crop(x, 0.33)
- y_tf = y.eval(feed_dict={x: x_np})
- self.assertAllEqual(y_tf, y_np)
- self.assertAllEqual(y_tf.shape, y_np.shape)
+ x_shapes = [[240, 320, 3], [5, 240, 320, 3]]
+ expected_y_shapes = [[80, 106, 3], [5, 80, 106, 3]]
+
+ for x_shape, y_shape in zip(x_shapes, expected_y_shapes):
+ x_np = np.zeros(x_shape, dtype=np.int32)
+ y_np = np.zeros(y_shape, dtype=np.int32)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32)
+ y = image_ops.central_crop(x, 0.33)
+ y_tf = y.eval(feed_dict={x: x_np})
+ self.assertAllEqual(y_tf, y_np)
+ self.assertAllEqual(y_tf.shape, y_np.shape)
def testShapeInference(self):
- # Test no-op fraction=1.0
+ # Test no-op fraction=1.0, with 3-D tensors.
self._assertShapeInference([50, 60, 3], 1.0, [50, 60, 3])
self._assertShapeInference([None, 60, 3], 1.0, [None, 60, 3])
self._assertShapeInference([50, None, 3], 1.0, [50, None, 3])
self._assertShapeInference([None, None, 3], 1.0, [None, None, 3])
self._assertShapeInference([50, 60, None], 1.0, [50, 60, None])
self._assertShapeInference([None, None, None], 1.0, [None, None, None])
- self._assertShapeInference(None, 1.0, None)
- # TODO(toddw): Currently central_crop() doesn't infer the result shape even
- # when it's possible. If we change it to do so, we can test as follows:
- #
- # self._assertShapeInference([50, 60, 3], 0.5, [25, 30, 3])
- # self._assertShapeInference([None, 60, 3], 0.5, [None, 30, 3])
- # self._assertShapeInference([50, None, 3], 0.5, [25, None, 3])
- # self._assertShapeInference([None, None, 3], 0.5, [None, None, 3])
- # self._assertShapeInference([50, 60, None], 0.5, [25, 30, None])
- # self._assertShapeInference([None, None, None], 0.5, [None, None, None])
- # self._assertShapeInference(None, 0.5, None)
- def testError(self):
+ # Test no-op fraction=0.5, with 3-D tensors.
+ self._assertShapeInference([50, 60, 3], 0.5, [26, 30, 3])
+ self._assertShapeInference([None, 60, 3], 0.5, [None, 30, 3])
+ self._assertShapeInference([50, None, 3], 0.5, [26, None, 3])
+ self._assertShapeInference([None, None, 3], 0.5, [None, None, 3])
+ self._assertShapeInference([50, 60, None], 0.5, [26, 30, None])
+ self._assertShapeInference([None, None, None], 0.5, [None, None, None])
+
+ # Test no-op fraction=1.0, with 4-D tensors.
+ self._assertShapeInference([5, 50, 60, 3], 1.0, [5, 50, 60, 3])
+ self._assertShapeInference([5, None, 60, 3], 1.0, [5, None, 60, 3])
+ self._assertShapeInference([5, 50, None, 3], 1.0, [5, 50, None, 3])
+ self._assertShapeInference([5, None, None, 3], 1.0, [5, None, None, 3])
+ self._assertShapeInference([5, 50, 60, None], 1.0, [5, 50, 60, None])
+ self._assertShapeInference([5, None, None, None], 1.0,
+ [5, None, None, None])
+ self._assertShapeInference([None, None, None, None], 1.0,
+ [None, None, None, None])
+
+ # Test no-op fraction=0.5, with 4-D tensors.
+ self._assertShapeInference([5, 50, 60, 3], 0.5, [5, 26, 30, 3])
+ self._assertShapeInference([5, None, 60, 3], 0.5, [5, None, 30, 3])
+ self._assertShapeInference([5, 50, None, 3], 0.5, [5, 26, None, 3])
+ self._assertShapeInference([5, None, None, 3], 0.5, [5, None, None, 3])
+ self._assertShapeInference([5, 50, 60, None], 0.5, [5, 26, 30, None])
+ self._assertShapeInference([5, None, None, None], 0.5,
+ [5, None, None, None])
+ self._assertShapeInference([None, None, None, None], 0.5,
+ [None, None, None, None])
+
+ def testErrorOnInvalidCentralCropFractionValues(self):
x_shape = [13, 9, 3]
x_np = np.ones(x_shape, dtype=np.float32)
- with self.test_session(use_gpu=True):
- x = constant_op.constant(x_np, shape=x_shape)
- with self.assertRaises(ValueError):
- _ = image_ops.central_crop(x, 0.0)
- with self.assertRaises(ValueError):
- _ = image_ops.central_crop(x, 1.01)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_shape)
+ with self.assertRaises(ValueError):
+ _ = image_ops.central_crop(x, 0.0)
+ with self.assertRaises(ValueError):
+ _ = image_ops.central_crop(x, 1.01)
+
+ def testErrorOnInvalidShapes(self):
+ x_shapes = [None, [], [3], [3, 9], [3, 9, 3, 9, 3]]
+ for x_shape in x_shapes:
+ x_np = np.ones(x_shape, dtype=np.float32)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_shape)
+ with self.assertRaises(ValueError):
+ _ = image_ops.central_crop(x, 0.5)
def testNameScope(self):
x_shape = [13, 9, 3]
x_np = np.ones(x_shape, dtype=np.float32)
- with self.test_session(use_gpu=True):
- y = image_ops.central_crop(x_np, 1.0)
- self.assertTrue(y.op.name.startswith("central_crop"))
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ y = image_ops.central_crop(x_np, 1.0)
+ self.assertTrue(y.op.name.startswith("central_crop"))
class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 54b08a564b..0c2f5b06c4 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2311,13 +2311,22 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
raise ValueError("keep_prob must be a scalar tensor or a float in the "
"range (0, 1], got %g" % keep_prob)
- keep_prob = ops.convert_to_tensor(
- keep_prob, dtype=x.dtype, name="keep_prob")
- keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
- # Do nothing if we know keep_prob == 1
- if tensor_util.constant_value(keep_prob) == 1:
+ # Early return if nothing needs to be dropped.
+ if isinstance(keep_prob, float) and keep_prob == 1:
return x
+ if context.executing_eagerly():
+ if isinstance(keep_prob, ops.EagerTensor):
+ if keep_prob.numpy() == 1:
+ return x
+ else:
+ keep_prob = ops.convert_to_tensor(
+ keep_prob, dtype=x.dtype, name="keep_prob")
+ keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
+ # Do nothing if we know keep_prob == 1
+ if tensor_util.constant_value(keep_prob) == 1:
+ return x
noise_shape = _get_noise_shape(x, noise_shape)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 294ee0e328..d88fd836f5 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -123,6 +123,30 @@ class Variable(checkpointable.CheckpointableBase):
various `Optimizer` classes use this collection as the default list of
variables to optimize.
+ WARNING: tf.Variable objects have a non-intuitive memory model. A Variable is
+ represented internally as a mutable Tensor which can non-deterministically
+ alias other Tensors in a graph. The set of operations which consume a Variable
+ and can lead to aliasing is undetermined and can change across TensorFlow
+ versions. Avoid writing code which relies on the value of a Variable either
+ changing or not changing as other operations happen. For example, using
+ Variable objects or simple functions thereof as predicates in a `tf.cond` is
+ dangerous and error-prone:
+
+ ```
+ v = tf.Variable(True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken.
+ ```
+
+ Here replacing tf.Variable with tf.contrib.eager.Variable will fix any
+ nondeterminism issues.
+
+ To use the replacement for variables which does
+ not have these issues:
+
+ * Replace `tf.Variable` with `tf.contrib.eager.Variable`;
+ * Call `tf.get_variable_scope().set_use_resource(True)` inside a
+ `tf.variable_scope` before the `tf.get_variable()` call.
+
@compatibility(eager)
`tf.Variable` is not compatible with eager execution. Use
`tf.contrib.eager.Variable` instead which is compatible with both eager
@@ -235,7 +259,7 @@ class Variable(checkpointable.CheckpointableBase):
constraint=constraint)
def __repr__(self):
- if context.executing_eagerly():
+ if context.executing_eagerly() and not self._in_graph_mode:
return "<tf.Variable '%s' shape=%s dtype=%s, numpy=%s>" % (
self.name, self.get_shape(), self.dtype.name,
ops.numpy_text(self.read_value(), is_repr=True))
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 5f1fafb9dc..500dc30cc3 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -33,8 +33,9 @@ limitations under the License.
%rename("%s") TFE_ContextAsyncClearError;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
+%rename("%s") TFE_Py_SetEagerTensorProfiler;
%rename("%s") TFE_Py_RegisterExceptionClass;
-%rename("%s") TFE_Py_RegisterBackwardFunctionGetter;
+%rename("%s") TFE_Py_RegisterGradientFunction;
%rename("%s") TFE_Py_RegisterFallbackExceptionClass;
%rename("%s") TFE_Py_RegisterResourceVariableType;
%rename("%s") TFE_Py_Execute;
@@ -60,6 +61,7 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetAsync;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
+%rename("%s") TFE_Py_TensorShapeOnDevice;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 8f1d5a099f..24a13c0f33 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -104,10 +104,10 @@ class SavedModelBuilder(object):
Args:
assets_collection_to_add: The collection where the asset paths are setup.
"""
- asset_source_filepath_list = _maybe_save_assets(assets_collection_to_add)
+ asset_filename_map = _maybe_save_assets(assets_collection_to_add)
# Return if there are no assets to write.
- if len(asset_source_filepath_list) is 0:
+ if not asset_filename_map:
tf_logging.info("No assets to write.")
return
@@ -119,12 +119,10 @@ class SavedModelBuilder(object):
file_io.recursive_create_dir(assets_destination_dir)
# Copy each asset from source path to destination path.
- for asset_source_filepath in asset_source_filepath_list:
- asset_source_filename = os.path.basename(asset_source_filepath)
-
+ for asset_basename, asset_source_filepath in asset_filename_map.items():
asset_destination_filepath = os.path.join(
compat.as_bytes(assets_destination_dir),
- compat.as_bytes(asset_source_filename))
+ compat.as_bytes(asset_basename))
# Only copy the asset file to the destination if it does not already
# exist. This is to ensure that an asset with the same name defined as
@@ -476,16 +474,17 @@ def _maybe_save_assets(assets_collection_to_add=None):
assets_collection_to_add: The collection where the asset paths are setup.
Returns:
- The list of filepaths to the assets in the assets collection.
+ A dict of asset basenames for saving to the original full path to the asset.
Raises:
ValueError: Indicating an invalid filepath tensor.
"""
- asset_source_filepath_list = []
+ # Map of target file names to original filenames
+ asset_filename_map = {}
if assets_collection_to_add is None:
tf_logging.info("No assets to save.")
- return asset_source_filepath_list
+ return asset_filename_map
# Iterate over the supplied asset collection, build the `AssetFile` proto
# and add them to the collection with key `constants.ASSETS_KEY`, in the
@@ -495,15 +494,71 @@ def _maybe_save_assets(assets_collection_to_add=None):
if not asset_source_filepath:
raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)
- asset_source_filename = os.path.basename(asset_source_filepath)
+ asset_filename = _get_asset_filename_to_add(
+ asset_source_filepath, asset_filename_map)
# Build `AssetFile` proto and add it to the asset collection in the graph.
- _add_asset_to_collection(asset_source_filename, asset_tensor)
+ # Note that this should be done even when the file is a duplicate of an
+ # already-added file, as the tensor reference should still exist.
+ _add_asset_to_collection(asset_filename, asset_tensor)
- asset_source_filepath_list.append(asset_source_filepath)
+ # In the cases where we are adding a duplicate, this will result in the
+ # last of the filepaths being the one used for copying the file to the
+ # SavedModel. Since the files in question are the same, it doesn't matter
+ # either way.
+ asset_filename_map[asset_filename] = asset_source_filepath
tf_logging.info("Assets added to graph.")
- return asset_source_filepath_list
+ return asset_filename_map
+
+
+def _get_asset_filename_to_add(asset_filepath, asset_filename_map):
+ """Get a unique basename to add to the SavedModel if this file is unseen.
+
+ Assets come from users as full paths, and we save them out to the
+ SavedModel as basenames. In some cases, the basenames collide. Here,
+ we dedupe asset basenames by first checking if the file is the same,
+ and, if different, generate and return an index-suffixed basename
+ that can be used to add the asset to the SavedModel.
+
+ Args:
+ asset_filepath: the full path to the asset that is being saved
+ asset_filename_map: a dict of filenames used for saving the asset in
+ the SavedModel to full paths from which the filenames were derived.
+
+ Returns:
+ Uniquified filename string if the file is not a duplicate, or the original
+ filename if the file has already been seen and saved.
+ """
+ asset_filename = os.path.basename(asset_filepath)
+
+ if asset_filename not in asset_filename_map:
+ # This is an unseen asset. Safe to add.
+ return asset_filename
+
+ other_asset_filepath = asset_filename_map[asset_filename]
+ if other_asset_filepath == asset_filepath:
+ # This is the same file, stored twice in the collection list. No need
+ # to make unique.
+ return asset_filename
+
+ # Else, asset_filename is in the map, and the filepath is different. Dedupe.
+ if not file_io.filecmp(asset_filepath, other_asset_filepath):
+ # Files are different; dedupe filenames.
+ return _get_unique_asset_filename(asset_filename, asset_filename_map)
+
+ # Files are the same; don't make unique.
+ return asset_filename
+
+
+def _get_unique_asset_filename(asset_filename, asset_filename_map):
+ i = 1
+ unique_filename = asset_filename
+ while unique_filename in asset_filename_map:
+ unique_filename = compat.as_bytes("_").join(
+ [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
+ i += 1
+ return unique_filename
def _asset_path_from_tensor(path_tensor):
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 1b83d60df9..7302c77ad5 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -64,9 +64,12 @@ class SavedModelTest(test.TestCase):
self.assertEqual(variable_value, v.eval())
def _build_asset_collection(self, asset_file_name, asset_file_contents,
- asset_file_tensor_name):
+ asset_file_tensor_name, asset_subdir=""):
+ parent_dir = os.path.join(
+ compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_subdir))
+ file_io.recursive_create_dir(parent_dir)
asset_filepath = os.path.join(
- compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_file_name))
+ compat.as_bytes(parent_dir), compat.as_bytes(asset_file_name))
file_io.write_string_to_file(asset_filepath, asset_file_contents)
asset_file_tensor = constant_op.constant(
asset_filepath, name=asset_file_tensor_name)
@@ -77,10 +80,11 @@ class SavedModelTest(test.TestCase):
def _validate_asset_collection(self, export_dir, graph_collection_def,
expected_asset_file_name,
expected_asset_file_contents,
- expected_asset_tensor_name):
+ expected_asset_tensor_name,
+ asset_id=0):
assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
asset = meta_graph_pb2.AssetFileDef()
- assets_any[0].Unpack(asset)
+ assets_any[asset_id].Unpack(asset)
assets_path = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.ASSETS_DIRECTORY),
@@ -634,6 +638,141 @@ class SavedModelTest(test.TestCase):
compat.as_bytes("ignored.txt"))
self.assertFalse(file_io.file_exists(ignored_asset_path))
+ def testAssetsNameCollisionDiffFile(self):
+ export_dir = self._get_export_dir("test_assets_name_collision_diff_file")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar bak", "asset_file_tensor",
+ asset_subdir="1")
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor_1",
+ asset_subdir="2")
+
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], assets_collection=asset_collection)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ foo_graph = loader.load(sess, ["foo"], export_dir)
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar bak",
+ "asset_file_tensor:0")
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt_1", "foo bar baz",
+ "asset_file_tensor_1:0",
+ asset_id=1)
+
+ def testAssetsNameCollisionSameFilepath(self):
+ export_dir = self._get_export_dir("test_assets_name_collision_same_path")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor")
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor_1")
+
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], assets_collection=asset_collection)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ foo_graph = loader.load(sess, ["foo"], export_dir)
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz",
+ "asset_file_tensor:0")
+ # The second tensor should be recorded, but the same.
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz",
+ "asset_file_tensor_1:0",
+ asset_id=1)
+ ignored_asset_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.ASSETS_DIRECTORY),
+ compat.as_bytes("hello42.txt_1"))
+ self.assertFalse(file_io.file_exists(ignored_asset_path))
+
+ def testAssetsNameCollisionSameFile(self):
+ export_dir = self._get_export_dir("test_assets_name_collision_same_file")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor",
+ asset_subdir="1")
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor_1",
+ asset_subdir="2")
+
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], assets_collection=asset_collection)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ foo_graph = loader.load(sess, ["foo"], export_dir)
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz",
+ "asset_file_tensor:0")
+ # The second tensor should be recorded, but the same.
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz",
+ "asset_file_tensor_1:0",
+ asset_id=1)
+ ignored_asset_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.ASSETS_DIRECTORY),
+ compat.as_bytes("hello42.txt_1"))
+ self.assertFalse(file_io.file_exists(ignored_asset_path))
+
+ def testAssetsNameCollisionManyFiles(self):
+ export_dir = self._get_export_dir("test_assets_name_collision_many_files")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ for i in range(5):
+ idx = str(i)
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz " + idx, "asset_file_tensor_" + idx,
+ asset_subdir=idx)
+
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], assets_collection=asset_collection)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ foo_graph = loader.load(sess, ["foo"], export_dir)
+ for i in range(1, 5):
+ idx = str(i)
+ self._validate_asset_collection(
+ export_dir, foo_graph.collection_def, "hello42.txt_" + idx,
+ "foo bar baz " + idx, "asset_file_tensor_{}:0".format(idx),
+ asset_id=i)
+
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz 0",
+ "asset_file_tensor_0:0")
+
def testCustomMainOp(self):
export_dir = self._get_export_dir("test_main_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 9be8b6aafe..bc68f24c6f 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -180,11 +180,10 @@ class AdamOptimizerTest(test.TestCase):
self.assertIn(beta1_power, opt_variables)
self.assertIn(beta2_power, opt_variables)
- with ops.Graph().as_default():
- # Shouldn't return non-slot variables from other graphs.
- self.assertEqual(0, len(opt.variables()))
-
if not context.executing_eagerly():
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index df528d54d6..9b40817f55 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -336,6 +336,8 @@ class CheckpointSaverListener(object):
def after_save(self, session, global_step_value):
print('Done writing checkpoint.')
+ if decided_to_stop_training():
+ return True
def end(self, session, global_step_value):
print('Done with the session.')
@@ -354,6 +356,11 @@ class CheckpointSaverListener(object):
implementors should implement the `end()` method to handle actions related to
the last checkpoint save. But the listener should not act twice if
`after_save()` already handled this last checkpoint save.
+
+ A `CheckpointSaverListener` can request training to be stopped, by returning
+ True in `after_save`. Please note that, in replicated distributed training
+ setting, only `chief` should use this behavior. Otherwise each worker will do
+ their own evaluation, which may be wasteful of resources.
"""
def begin(self):
@@ -453,7 +460,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
- self._save(run_context.session, global_step)
+ if self._save(run_context.session, global_step):
+ run_context.request_stop()
def end(self, session):
last_step = session.run(self._global_step_tensor)
@@ -463,7 +471,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
l.end(session, last_step)
def _save(self, session, step):
- """Saves the latest checkpoint."""
+ """Saves the latest checkpoint, returns should_stop."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
for l in self._listeners:
@@ -475,8 +483,14 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
step)
+ should_stop = False
for l in self._listeners:
- l.after_save(session, step)
+ if l.after_save(session, step):
+ logging.info(
+ "A CheckpointSaverListener requested that training be stopped. "
+ "listener: {}".format(l))
+ should_stop = True
+ return should_stop
def _get_saver(self):
if self._saver is not None:
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index 7344ce2758..21c584f2ee 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -58,6 +58,7 @@ class MockCheckpointSaverListener(
self.before_save_count = 0
self.after_save_count = 0
self.end_count = 0
+ self.ask_for_stop = False
def begin(self):
self.begin_count += 1
@@ -67,6 +68,8 @@ class MockCheckpointSaverListener(
def after_save(self, session, global_step):
self.after_save_count += 1
+ if self.ask_for_stop:
+ return True
def end(self, session, global_step):
self.end_count += 1
@@ -471,6 +474,25 @@ class CheckpointSaverHookTest(test.TestCase):
'end': 1
}, listener_counts)
+ def test_listener_stops_training_in_after_save(self):
+ with ops.Graph().as_default():
+ scaffold = monitored_session.Scaffold()
+ variables.get_or_create_global_step()
+ train_op = training_util._increment_global_step(1)
+ listener = MockCheckpointSaverListener()
+ hook = basic_session_run_hooks.CheckpointSaverHook(
+ self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener])
+ with monitored_session.SingularMonitoredSession(
+ hooks=[hook], scaffold=scaffold,
+ checkpoint_dir=self.model_dir) as sess:
+ sess.run(train_op)
+ self.assertFalse(sess.should_stop())
+ sess.run(train_op)
+ self.assertFalse(sess.should_stop())
+ listener.ask_for_stop = True
+ sess.run(train_op)
+ self.assertTrue(sess.should_stop())
+
def test_listener_with_default_saver(self):
with ops.Graph().as_default():
global_step = variables.get_or_create_global_step()
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index a7ae6e50a9..87ba4dc91c 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -22,8 +22,9 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops_gen",
- "//tensorflow/python:ops",
+ "//tensorflow/python:platform",
"//tensorflow/python:saveable_object",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
@@ -41,6 +42,42 @@ py_test(
)
py_library(
+ name = "data_structures_base",
+ srcs = ["data_structures_base.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base",
+ ],
+)
+
+py_library(
+ name = "data_structures",
+ srcs = ["data_structures.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base",
+ ":data_structures_base",
+ ],
+)
+
+py_test(
+ name = "data_structures_test",
+ srcs = ["data_structures_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":data_structures",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/keras:layers",
+ ],
+)
+
+py_library(
name = "util",
srcs = ["util.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index e378f0e898..cfe7259e1b 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -591,11 +591,11 @@ class CheckpointableBase(object):
self._unconditional_checkpoint_dependencies):
if name == old_name:
self._unconditional_checkpoint_dependencies[index] = new_reference
- else:
+ elif current_object is None:
self._unconditional_checkpoint_dependencies.append(new_reference)
-
- self._unconditional_dependency_names[name] = checkpointable
- self._handle_deferred_dependencies(name=name, checkpointable=checkpointable)
+ self._unconditional_dependency_names[name] = checkpointable
+ self._handle_deferred_dependencies(
+ name=name, checkpointable=checkpointable)
return checkpointable
def _handle_deferred_dependencies(self, name, checkpointable):
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
new file mode 100644
index 0000000000..62cefa4f20
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -0,0 +1,251 @@
+"""Checkpointable data structures."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import six
+
+from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import data_structures_base
+
+
+# TODO(allenl): We could track regular Python data structures which get assigned
+# to Checkpointable objects. Making this work with restore-on-create would be
+# tricky; we'd need to re-create nested structures with our own wrapped objects
+# on assignment to an attribute, and track the user's original structure to make
+# sure they don't modify it except through the wrappers (since we could save the
+# user's updated structure, but would have no way to support restore-on-create
+# for those modifications).
+# TODO(allenl): A dictionary data structure would be good too.
+class CheckpointableDataStructure(
+ data_structures_base.CheckpointableDataStructureBase):
+ """Base class for data structures which contain checkpointable objects."""
+
+ def __init__(self):
+ self._layers = []
+ self.trainable = True
+
+ def _track_value(self, value, name):
+ """Add a dependency on `value`."""
+ if isinstance(value, checkpointable_lib.CheckpointableBase):
+ self._track_checkpointable(value, name=name)
+ else:
+ raise ValueError(
+ ("Only checkpointable objects (such as Layers or Optimizers) may be "
+ "stored in a List object. Got %s, which does not inherit from "
+ "CheckpointableBase.") % (value,))
+ if isinstance(value, (
+ base_layer.Layer,
+ data_structures_base.CheckpointableDataStructureBase)):
+ if value not in self._layers:
+ self._layers.append(value)
+ if hasattr(value, "_use_resource_variables"):
+ # In subclassed models, legacy layers (tf.layers) must always use
+ # resource variables.
+ value._use_resource_variables = True # pylint: disable=protected-access
+
+ @property
+ def layers(self):
+ return self._layers
+
+ @property
+ def trainable_weights(self):
+ if not self.trainable:
+ return []
+ weights = []
+ for layer in self.layers:
+ weights += layer.trainable_weights
+ return weights
+
+ @property
+ def non_trainable_weights(self):
+ weights = []
+ for layer in self.layers:
+ weights += layer.non_trainable_weights
+ if not self.trainable:
+ trainable_weights = []
+ for layer in self.layers:
+ trainable_weights += layer.trainable_weights
+ return trainable_weights + weights
+ return weights
+
+ @property
+ def weights(self):
+ return self.trainable_weights + self.non_trainable_weights
+
+ @property
+ def variables(self):
+ return self.weights
+
+ @property
+ def updates(self):
+ """Aggregate updates from any `Layer` instances."""
+ # Updates and conditional losses are forwarded as-is rather than being
+ # filtered based on inputs, since this is just a container and won't ever
+ # have any inputs.
+ aggregated = []
+ for layer in self.layers:
+ aggregated += layer.updates
+ return aggregated
+
+ @property
+ def losses(self):
+ """Aggregate losses from any `Layer` instances."""
+ aggregated = []
+ for layer in self.layers:
+ aggregated += layer.losses
+ return aggregated
+
+ def __hash__(self):
+ # Support object-identity hashing, so these structures can be used as keys
+ # in sets/dicts.
+ return id(self)
+
+ def __eq__(self, other):
+ # Similar to Tensors, checkpointable data structures use object-identity
+ # equality to support set/dict membership.
+ return self is other
+
+
+class List(CheckpointableDataStructure, collections.Sequence):
+ """An append-only sequence type which is checkpointable.
+
+ Maintains checkpoint dependencies on its contents (which must also be
+ checkpointable), and forwards any `Layer` metadata such as updates and losses.
+
+ Note that `List` is purely a container. It lets a `tf.keras.Model` or
+ other checkpointable object know about its contents, but does not call any
+ `Layer` instances which are added to it. To indicate a sequence of `Layer`
+ instances which should be called sequentially, use `tf.keras.Sequential`.
+
+ Example usage:
+ ```python
+ class HasList(tf.keras.Model):
+
+ def __init__(self):
+ super(HasList, self).__init__()
+ self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
+ self.layer_list.append(layers.Dense(4))
+
+ def call(self, x):
+ aggregation = 0.
+ for l in self.layer_list:
+ x = l(x)
+ aggregation += tf.reduce_sum(x)
+ return aggregation
+ ```
+
+ This kind of wrapping is necessary because `Checkpointable` objects do not
+ (yet) deeply inspect regular Python data structures, so for example assigning
+ a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
+ checkpoint dependency and does not add the `Layer` instance's weights to its
+ parent `Model`.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """Construct a new sequence. Arguments are passed to `list()`."""
+ super(List, self).__init__()
+ self._storage = list(*args, **kwargs)
+ for index, element in enumerate(self._storage):
+ self._track_value(element, name=self._name_element(index))
+
+ def _name_element(self, index):
+ return "%d" % (index,)
+
+ def append(self, value):
+ """Add a new checkpointable value."""
+ self._track_value(value, self._name_element(len(self._storage)))
+ self._storage.append(value)
+
+ def extend(self, values):
+ """Add a sequence of checkpointable values."""
+ for index_offset, value in enumerate(values):
+ self._track_value(
+ value, name=self._name_element(len(self._storage) + index_offset))
+ self._storage.extend(values)
+
+ def __iadd__(self, values):
+ self.extend(values)
+ return self
+
+ def __add__(self, other):
+ if isinstance(other, List):
+ return List(self._storage + other._storage) # pylint: disable=protected-access
+ else:
+ return List(self._storage + other)
+
+ def __getitem__(self, key):
+ return self._storage[key]
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __repr__(self):
+ return "List(%s)" % (repr(self._storage),)
+
+
+class Mapping(CheckpointableDataStructure, collections.Mapping):
+ """An append-only checkpointable mapping data structure with string keys.
+
+ Maintains checkpoint dependencies on its contents (which must also be
+ checkpointable), named based on its keys.
+
+ Note that once a key has been added, it may not be deleted or replaced. If
+ names may not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """Construct a new sequence. Arguments are passed to `dict()`."""
+ super(Mapping, self).__init__()
+ self._storage = dict(*args, **kwargs)
+ for key, value in self._storage.items():
+ self._track_value(value, name=self._name_element(key))
+
+ def _name_element(self, key):
+ if not isinstance(key, six.string_types):
+ raise TypeError(
+ "Mapping accepts only string keys, but got a key %s."
+ % repr(key))
+ return str(key)
+
+ def __setitem__(self, key, value):
+ current_value = self._storage.setdefault(key, value)
+ if current_value is not value:
+ raise ValueError(
+ ("Mappings are an append-only data structure. Tried to overwrite the "
+ "key '%s' with value %s, but it already contains %s")
+ % (key, value, current_value))
+ self._track_value(value, name=self._name_element(key))
+
+ def update(self, *args, **kwargs):
+ for key, value in dict(*args, **kwargs).items():
+ self[key] = value
+
+ def __getitem__(self, key):
+ return self._storage[key]
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __repr__(self):
+ return "Mapping(%s)" % (repr(self._storage),)
+
+ def __iter__(self):
+ return iter(self._storage)
diff --git a/tensorflow/python/training/checkpointable/data_structures_base.py b/tensorflow/python/training/checkpointable/data_structures_base.py
new file mode 100644
index 0000000000..f1b2cf105b
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/data_structures_base.py
@@ -0,0 +1,27 @@
+"""A trivial base class to avoid circular imports for isinstance checks."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.training.checkpointable import base as checkpointable_lib
+
+
+class CheckpointableDataStructureBase(checkpointable_lib.CheckpointableBase):
+ """Base class for data structures which contain checkpointable objects."""
+
+ pass
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
new file mode 100644
index 0000000000..31a0e8b622
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -0,0 +1,219 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import normalization
+from tensorflow.python.layers import core as non_keras_core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training.checkpointable import data_structures
+
+
+class HasList(training.Model):
+
+ def __init__(self):
+ super(HasList, self).__init__()
+ self.layer_list = data_structures.List([core.Dense(3)])
+ self.layer_list.append(core.Dense(4))
+ self.layer_list.extend(
+ [core.Dense(5),
+ core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
+ self.layer_list += [
+ core.Dense(7, bias_regularizer=math_ops.reduce_sum),
+ core.Dense(8)
+ ]
+ self.layer_list += (
+ data_structures.List([core.Dense(9)]) + data_structures.List(
+ [core.Dense(10)]))
+ self.layer_list.extend(
+ data_structures.List(
+ list(sequence=[core.Dense(11)]) + [core.Dense(12)]))
+ self.layers_with_updates = data_structures.List(
+ sequence=(normalization.BatchNormalization(),))
+
+ def call(self, x):
+ aggregation = 0.
+ for l in self.layer_list:
+ x = l(x)
+ aggregation += math_ops.reduce_sum(x)
+ bn, = self.layers_with_updates
+ return bn(x) / aggregation
+
+
+class ListTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTracking(self):
+ model = HasList()
+ output = model(array_ops.ones([32, 2]))
+ self.assertAllEqual([32, 12], output.shape)
+ self.assertEqual(2, len(model.layers))
+ self.assertIs(model.layer_list, model.layers[0])
+ self.assertEqual(10, len(model.layers[0].layers))
+ for index in range(10):
+ self.assertEqual(3 + index, model.layers[0].layers[index].units)
+ self.assertEqual(2, len(model._checkpoint_dependencies))
+ self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
+ self.assertIs(model.layers_with_updates,
+ model._checkpoint_dependencies[1].ref)
+ self.assertEqual(
+ 10, len(model._checkpoint_dependencies[0].ref._checkpoint_dependencies))
+ self.evaluate([v.initializer for v in model.variables])
+ self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
+ model.load_weights(save_path)
+ self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
+ self.evaluate(model.variables[0]))
+
+ def testUpdatesForwarded(self):
+ with context.graph_mode():
+ model = HasList()
+ model_input = array_ops.ones([32, 2])
+ model(model_input)
+ self.assertGreater(len(model.layers_with_updates[0].updates), 0)
+ self.assertEqual(set(model.layers_with_updates[0].updates),
+ set(model.updates))
+
+ with context.eager_mode():
+ model = HasList()
+ model_input = array_ops.ones([32, 2])
+ model(model_input)
+ self.assertEqual(0, len(model.updates))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testLossesForwarded(self):
+ model = HasList()
+ model_input = array_ops.ones([32, 2])
+ model(model_input)
+ self.assertEqual(2, len(model.losses))
+
+ def testNotCheckpointable(self):
+ class NotCheckpointable(object):
+ pass
+
+ with self.assertRaises(ValueError):
+ data_structures.List([NotCheckpointable()])
+
+ def testCallNotImplemented(self):
+ with self.assertRaisesRegexp(TypeError, "not callable"):
+ data_structures.List()(1.)
+
+ def testNoPop(self):
+ with self.assertRaises(AttributeError):
+ data_structures.List().pop()
+
+ def testNesting(self):
+ with context.graph_mode():
+ inner = data_structures.List()
+ outer = data_structures.List([inner])
+ inner.append(non_keras_core.Dense(1))
+ inner[0](array_ops.ones([2, 3]))
+ self.assertEqual(2, len(outer.variables))
+ self.assertIsInstance(
+ outer.variables[0],
+ resource_variable_ops.ResourceVariable)
+
+ def testHashing(self):
+ has_sequences = set([data_structures.List(),
+ data_structures.List()])
+ self.assertEqual(2, len(has_sequences))
+ self.assertNotIn(data_structures.List(), has_sequences)
+
+
+class HasMapping(training.Model):
+
+ def __init__(self):
+ super(HasMapping, self).__init__()
+ self.layer_dict = data_structures.Mapping(output=core.Dense(7))
+ self.layer_dict["norm"] = data_structures.List()
+ self.layer_dict["dense"] = data_structures.List()
+ self.layer_dict["dense"].extend(
+ [core.Dense(5),
+ core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
+ self.layer_dict["norm"].append(
+ normalization.BatchNormalization())
+ self.layer_dict["norm"].append(
+ normalization.BatchNormalization())
+
+ def call(self, x):
+ aggregation = 0.
+ for norm, dense in zip(self.layer_dict["norm"], self.layer_dict["dense"]):
+ x = norm(dense(x))
+ aggregation += math_ops.reduce_sum(x)
+ return self.layer_dict["output"](x) / aggregation
+
+
+class MappingTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTracking(self):
+ model = HasMapping()
+ output = model(array_ops.ones([32, 2]))
+ self.assertAllEqual([32, 7], output.shape)
+ self.assertEqual(1, len(model.layers))
+ self.assertIs(model.layer_dict, model.layers[0])
+ self.assertEqual(3, len(model.layers[0].layers))
+ self.assertEqual(1, len(model._checkpoint_dependencies))
+ self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref)
+ self.evaluate([v.initializer for v in model.variables])
+ test_var = model.layer_dict["output"].kernel
+ self.evaluate(test_var.assign(array_ops.ones([6, 7])))
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ self.evaluate(test_var.assign(array_ops.zeros([6, 7])))
+ model.load_weights(save_path)
+ self.assertAllEqual(numpy.ones([6, 7]),
+ self.evaluate(test_var))
+
+ def testNoOverwrite(self):
+ mapping = data_structures.Mapping()
+ original = data_structures.List()
+ mapping["a"] = original
+ with self.assertRaises(ValueError):
+ mapping["a"] = data_structures.List()
+ self.assertIs(original, mapping["a"])
+ with self.assertRaises(AttributeError):
+ del mapping["a"]
+ mapping.update(b=data_structures.Mapping())
+ with self.assertRaises(ValueError):
+ mapping.update({"b": data_structures.Mapping()})
+
+ def testNonStringKeys(self):
+ mapping = data_structures.Mapping()
+ with self.assertRaises(TypeError):
+ mapping[1] = data_structures.List()
+
+ def testHashing(self):
+ has_mappings = set([data_structures.Mapping(),
+ data_structures.Mapping()])
+ self.assertEqual(2, len(has_mappings))
+ self.assertNotIn(data_structures.Mapping(), has_mappings)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index 6caf29d83a..a07ad19a6e 100644
--- a/tensorflow/python/training/gradient_descent.py
+++ b/tensorflow/python/training/gradient_descent.py
@@ -71,6 +71,7 @@ class GradientDescentOptimizer(optimizer.Optimizer):
return var.scatter_sub(delta, use_locking=self._use_locking)
def _prepare(self):
- if not context.executing_eagerly() or self._learning_rate_tensor is None:
+ if not context.executing_eagerly() or not isinstance(
+ self._learning_rate_tensor, ops.EagerTensor):
self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
name="learning_rate")
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index 5370cafbcf..f89a9c5838 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -218,6 +221,26 @@ class GradientDescentOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
var1.eval())
+ def testCapturingInDefunWhileExecutingEagerly(self):
+ with context.eager_mode():
+ optimizer = gradient_descent.GradientDescentOptimizer(1.0)
+
+ def step():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ with backprop.GradientTape() as tape:
+ loss = v ** 2
+ grad = tape.gradient(loss, v)
+ optimizer.apply_gradients([(grad, v)])
+ return v.read_value()
+
+ compiled_step = function.defun(step)
+
+ self.assertEqual(float(step()), -1.0)
+ self.assertEqual(float(compiled_step()), -1.0)
+ # This shouldn't fail; in particular, the learning rate tensor should
+ # be an EagerTensor once again, not a graph Tensor.
+ self.assertEqual(float(step()), -1.0)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 7bd57ad3d8..f7e78071d8 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -134,7 +134,6 @@ class MomentumOptimizerTest(test.TestCase):
with context.eager_mode():
self.doTestBasic(use_resource=True, use_callable_params=True)
- @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testVariablesAcrossGraphs(self):
optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5)
with ops.Graph().as_default():
@@ -142,10 +141,7 @@ class MomentumOptimizerTest(test.TestCase):
[1.0, 2.0], dtype=dtypes.float32, name="var0")
var1 = resource_variable_ops.ResourceVariable(
[3.0, 4.0], dtype=dtypes.float32, name="var1")
- if context.executing_eagerly():
- loss = lambda: math_ops.reduce_sum(var0 + var1)
- else:
- loss = math_ops.reduce_sum(var0 + var1)
+ loss = math_ops.reduce_sum(var0 + var1)
optimizer.minimize(loss)
optimizer_variables = optimizer.variables()
self.assertStartsWith(optimizer_variables[0].name, "var0")
@@ -157,10 +153,7 @@ class MomentumOptimizerTest(test.TestCase):
[1.0, 2.0], dtype=dtypes.float32, name="var2")
var3 = resource_variable_ops.ResourceVariable(
[3.0, 4.0], dtype=dtypes.float32, name="var3")
- if context.executing_eagerly():
- loss = lambda: math_ops.reduce_sum(var2 + var3)
- else:
- loss = math_ops.reduce_sum(var2 + var3)
+ loss = math_ops.reduce_sum(var2 + var3)
optimizer.minimize(loss)
optimizer_variables = optimizer.variables()
self.assertStartsWith(optimizer_variables[0].name, "var2")
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index 3cb3877cc2..974f75777f 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -95,7 +95,8 @@ class SessionManager(object):
ready_op=None,
ready_for_local_init_op=None,
graph=None,
- recovery_wait_secs=30):
+ recovery_wait_secs=30,
+ local_init_run_options=None):
"""Creates a SessionManager.
The `local_init_op` is an `Operation` that is run always after a new session
@@ -127,6 +128,8 @@ class SessionManager(object):
to run local_init_op.
graph: The `Graph` that the model will use.
recovery_wait_secs: Seconds between checks for the model to be ready.
+ local_init_run_options: RunOptions to be passed to session.run when
+ executing the local_init_op.
Raises:
ValueError: If ready_for_local_init_op is not None but local_init_op is
@@ -141,6 +144,7 @@ class SessionManager(object):
self._graph = graph
self._recovery_wait_secs = recovery_wait_secs
self._target = None
+ self._local_init_run_options = local_init_run_options
if ready_for_local_init_op is not None and local_init_op is None:
raise ValueError("If you pass a ready_for_local_init_op "
"you must also pass a local_init_op "
@@ -485,7 +489,7 @@ class SessionManager(object):
is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
if is_ready_for_local_init:
logging.info("Running local_init_op.")
- sess.run(self._local_init_op)
+ sess.run(self._local_init_op, options=self._local_init_run_options)
logging.info("Done running local_init_op.")
return True, None
else:
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 7389e344c7..372ea415df 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -225,7 +225,8 @@ class Supervisor(object):
checkpoint_basename="model.ckpt",
session_manager=None,
summary_writer=USE_DEFAULT,
- init_fn=None):
+ init_fn=None,
+ local_init_run_options=None):
"""Create a `Supervisor`.
Args:
@@ -294,6 +295,8 @@ class Supervisor(object):
init_fn: Optional callable used to initialize the model. Called
after the optional `init_op` is called. The callable must accept one
argument, the session being initialized.
+ local_init_run_options: RunOptions to be passed as the SessionManager
+ local_init_run_options parameter.
Returns:
A `Supervisor`.
@@ -327,6 +330,7 @@ class Supervisor(object):
self._recovery_wait_secs = recovery_wait_secs
self._stop_grace_secs = stop_grace_secs
self._init_fn = init_fn
+ self._local_init_run_options = local_init_run_options
# Set all attributes related to checkpointing and writing events to None.
# Afterwards, set them appropriately for chief supervisors, as these are
@@ -362,7 +366,8 @@ class Supervisor(object):
ready_op=self._ready_op,
ready_for_local_init_op=self._ready_for_local_init_op,
graph=self._graph,
- recovery_wait_secs=self._recovery_wait_secs)
+ recovery_wait_secs=self._recovery_wait_secs,
+ local_init_run_options=self._local_init_run_options)
else:
self._session_manager = session_manager
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index d05e1d2c83..0877b2a8a2 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -119,18 +119,18 @@ def create_global_step(graph=None):
graph = graph or ops.get_default_graph()
if get_global_step(graph) is not None:
raise ValueError('"global_step" already exists.')
+ if context.executing_eagerly():
+ with ops.device('cpu:0'):
+ return variable_scope.get_variable(
+ ops.GraphKeys.GLOBAL_STEP,
+ shape=[],
+ dtype=dtypes.int64,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES,
+ ops.GraphKeys.GLOBAL_STEP])
# Create in proper graph and base name_scope.
with graph.as_default() as g, g.name_scope(None):
- if context.executing_eagerly():
- with ops.device('cpu:0'):
- return variable_scope.get_variable(
- ops.GraphKeys.GLOBAL_STEP,
- shape=[],
- dtype=dtypes.int64,
- initializer=init_ops.zeros_initializer(),
- trainable=False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES,
- ops.GraphKeys.GLOBAL_STEP])
return variable_scope.get_variable(
ops.GraphKeys.GLOBAL_STEP,
shape=[],
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index b0f37f8cb9..ec740abdd1 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -237,6 +237,62 @@ def _warm_start_var_with_vocab(var,
# pylint: enable=protected-access
+def _get_grouped_variables(vars_to_warm_start):
+ """Collects and groups (possibly partitioned) variables into a dictionary.
+
+ The variables can be provided explicitly through vars_to_warm_start, or they
+ are retrieved from collections (see below).
+
+ Args:
+ vars_to_warm_start: One of the following:
+
+ - A regular expression (string) that captures which variables to
+ warm-start (see tf.get_collection). This expression will only consider
+ variables in the TRAINABLE_VARIABLES collection.
+ - A list of Variables to warm-start.
+ - A list of strings, each representing a full variable name to warm-start.
+ - `None`, in which case only variables specified in
+ `var_name_to_vocab_info` will be warm-started.
+ Returns:
+ A dictionary mapping variable names (strings) to lists of Variables.
+ Raises:
+ ValueError: If vars_to_warm_start is not a string, `None`, a list of
+ `Variables`, or a list of strings.
+ """
+ if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
+ # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
+ # everything (in TRAINABLE_VARIABLES) here.
+ list_of_vars = ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES,
+ scope=vars_to_warm_start)
+ elif isinstance(vars_to_warm_start, list):
+ if all([isinstance(v, str) for v in vars_to_warm_start]):
+ list_of_vars = []
+ for v in vars_to_warm_start:
+ list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ scope=v)
+ elif all([_is_variable(v) for v in vars_to_warm_start]):
+ list_of_vars = vars_to_warm_start
+ else:
+ raise ValueError("If `vars_to_warm_start` is a list, it must be all "
+ "`Variable` or all `str`. Given types are {}".format(
+ [type(v) for v in vars_to_warm_start]))
+ else:
+ raise ValueError("`vars_to_warm_start must be a `list` or `str`. Given "
+ "type is {}".format(type(vars_to_warm_start)))
+ # We have to deal with partitioned variables, since get_collection flattens
+ # out the list.
+ grouped_variables = {}
+ for v in list_of_vars:
+ if not isinstance(v, list):
+ var_name = _infer_var_name([v])
+ else:
+ var_name = _infer_var_name(v)
+ grouped_variables.setdefault(var_name, []).append(v)
+
+ return grouped_variables
+
+
@tf_export("train.warm_start")
def warm_start(ckpt_to_initialize_from,
vars_to_warm_start=".*",
@@ -251,10 +307,19 @@ def warm_start(ckpt_to_initialize_from,
ckpt_to_initialize_from: [Required] A string specifying the directory with
checkpoint file(s) or path to checkpoint from which to warm-start the
model parameters.
- vars_to_warm_start: [Optional] A regular expression that captures which
- variables to warm-start (see tf.get_collection). Defaults to `'.*'`,
- which warm-starts all variables. If `None` is explicitly given, only
- variables specified in `var_name_to_vocab_info` will be warm-started.
+ vars_to_warm_start: [Optional] One of the following:
+
+ - A regular expression (string) that captures which variables to
+ warm-start (see tf.get_collection). This expression will only consider
+ variables in the TRAINABLE_VARIABLES collection.
+ - A list of Variables to warm-start.
+ - A list of strings, each representing a full variable name to warm-start.
+ - `None`, in which case only variables specified in
+ `var_name_to_vocab_info` will be warm-started.
+
+ Defaults to `'.*'`, which warm-starts all variables in the
+ TRAINABLE_VARIABLES collection. Note that this excludes variables such as
+ accumulators and moving statistics from batch norm.
var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
VocabInfo. The variable names should be "full" variables, not the names
of the partitions. If not explicitly provided, the variable is assumed to
@@ -274,21 +339,7 @@ def warm_start(ckpt_to_initialize_from,
if var_name_to_prev_var_name is None:
var_name_to_prev_var_name = {}
logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
- # We have to deal with partitioned variables, since get_collection flattens
- # out the list.
- grouped_variables = {}
- # Both vars_to_warm_start = '.*' and
- # vars_to_warm_start = None will match everything here.
- for v in ops.get_collection(
- # TODO(eddz): Allow for different collections here (to support
- # warm-starting accumulators).
- ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=vars_to_warm_start):
- if not isinstance(v, list):
- var_name = _infer_var_name([v])
- else:
- var_name = _infer_var_name(v)
- grouped_variables.setdefault(var_name, []).append(v)
+ grouped_variables = _get_grouped_variables(vars_to_warm_start)
# Keep track of which var_names in var_name_to_prev_var_name and
# var_name_to_vocab_info have been used. Err on the safer side by throwing an
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 7e8cbd6bae..6a4c207d79 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.training import warm_starting_util as ws_util
ones = init_ops.ones_initializer
norms = init_ops.truncated_normal_initializer
rand = init_ops.random_uniform_initializer
+zeros = init_ops.zeros_initializer
class WarmStartingUtilTest(test.TestCase):
@@ -305,6 +306,46 @@ class WarmStartingUtilTest(test.TestCase):
self.assertAllEqual([[0.5], [0.], [0.]],
fruit_weights_vars[1].eval(sess))
+ def testWarmStart_ListOfVariables(self):
+ # Save checkpoint from which to warm-start.
+ _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
+ initializer=ones())
+ # Verify we initialized the values correctly.
+ self.assertAllEqual(np.ones([10, 1]), prev_int_val)
+
+ # New graph, new session with warm-starting.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ # Initialize with zeros.
+ var = variable_scope.get_variable(
+ "v1",
+ shape=[10, 1],
+ initializer=zeros())
+ ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var])
+ sess.run(variables.global_variables_initializer())
+ # Verify weights were correctly warm-started (init overridden to ones).
+ self.assertAllEqual(var.eval(), prev_int_val)
+
+ def testWarmStart_ListOfStrings(self):
+ # Save checkpoint from which to warm-start.
+ _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
+ initializer=ones())
+ # Verify we initialized the values correctly.
+ self.assertAllEqual(np.ones([10, 1]), prev_int_val)
+
+ # New graph, new session with warm-starting.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ # Initialize with zeros.
+ var = variable_scope.get_variable(
+ "v1",
+ shape=[10, 1],
+ initializer=zeros())
+ ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"])
+ sess.run(variables.global_variables_initializer())
+ # Verify weights were correctly warm-started (init overridden to ones).
+ self.assertAllEqual(var.eval(), prev_int_val)
+
def testWarmStart_SparseColumnIntegerized(self):
# Create feature column.
sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)
diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i
index 6aeaa0e31b..f423553faa 100644
--- a/tensorflow/python/util/stat_summarizer.i
+++ b/tensorflow/python/util/stat_summarizer.i
@@ -73,7 +73,7 @@ void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss);
return ss;
}
}
-
+%include "tensorflow/core/util/stat_summarizer_options.h"
%include "tensorflow/core/util/stat_summarizer.h"
%unignoreall
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 7872f52172..fbd6561767 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -42,32 +42,67 @@ def currentframe():
return _inspect.stack()[1][0]
-def getargspec(object): # pylint: disable=redefined-builtin
+def getargspec(obj):
"""TFDecorator-aware replacement for inspect.getargspec.
Args:
- object: A callable (function or partial function), possibly decorated.
+ obj: A function, partial function, or callable object, possibly
+ decorated.
Returns:
The `ArgSpec` that describes the signature of the outermost decorator that
changes the callable's signature. If the callable is not decorated,
- `inspect.getargspec()` will be called directly on the callable.
+ `inspect.getargspec()` will be called directly on the object.
Raises:
- ValueError: When callable's function signature can not be expressed with
- ArgSpec.
+ ValueError: When callable's signature can not be expressed with
+ ArgSpec.
+ TypeError: For objects of unsupported types.
"""
+ if isinstance(obj, functools.partial):
+ return _get_argspec_for_partial(obj)
- def get_argspec_with_decorator(obj):
- decorators, target = tf_decorator.unwrap(obj)
- return next((d.decorator_argspec
- for d in decorators
- if d.decorator_argspec is not None),
- _inspect.getargspec(target))
+ decorators, target = tf_decorator.unwrap(obj)
+
+ spec = next((d.decorator_argspec
+ for d in decorators
+ if d.decorator_argspec is not None), None)
+ if spec:
+ return spec
+
+ try:
+ # Python3 will handle most callables here (not partial).
+ return _inspect.getargspec(target)
+ except TypeError:
+ pass
+
+ if isinstance(target, type):
+ try:
+ return _inspect.getargspec(target.__init__)
+ except TypeError:
+ pass
+
+ try:
+ return _inspect.getargspec(target.__new__)
+ except TypeError:
+ pass
- if not isinstance(object, functools.partial):
- return get_argspec_with_decorator(object)
+ # The `type(target)` ensures that if a class is received we don't return
+ # the signature of it's __call__ method.
+ return _inspect.getargspec(type(target).__call__)
+
+def _get_argspec_for_partial(obj):
+ """Implements `getargspec` for `functools.partial` objects.
+
+ Args:
+ obj: The `functools.partial` obeject
+ Returns:
+ An `inspect.ArgSpec`
+ Raises:
+ ValueError: When callable's signature can not be expressed with
+ ArgSpec.
+ """
# When callable is a functools.partial object, we construct its ArgSpec with
# following strategy:
# - If callable partial contains default value for positional arguments (ie.
@@ -97,10 +132,10 @@ def getargspec(object): # pylint: disable=redefined-builtin
# value and ensures all following arguments also have default values. When
# this is not true, a ValueError is raised.
- n_prune_args = len(object.args)
- partial_keywords = object.keywords or {}
+ n_prune_args = len(obj.args)
+ partial_keywords = obj.keywords or {}
- args, varargs, keywords, defaults = get_argspec_with_decorator(object.func)
+ args, varargs, keywords, defaults = getargspec(obj.func)
# Pruning first n_prune_args arguments.
args = args[n_prune_args:]
@@ -137,11 +172,34 @@ def getargspec(object): # pylint: disable=redefined-builtin
return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
-def getfullargspec(obj): # pylint: disable=redefined-builtin
- """TFDecorator-aware replacement for `inspect.getfullargspec`/`getargspec`.
+if hasattr(_inspect, 'getfullargspec'):
+ _getfullargspec = _inspect.getfullargspec
+else:
+
+ def _getfullargspec(target):
+ """A python2 version of getfullargspec.
+
+ Args:
+ target: the target object to inspect.
+ Returns:
+ A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
+ """
+ argspecs = _inspect.getargspec(target)
+ fullargspecs = FullArgSpec(
+ args=argspecs.args,
+ varargs=argspecs.varargs,
+ varkw=argspecs.keywords,
+ defaults=argspecs.defaults,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+ return fullargspecs
+
- This wrapper uses `inspect.getfullargspec` if available and falls back to
- `inspect.getargspec` in Python 2.
+def getfullargspec(obj):
+ """TFDecorator-aware replacement for `inspect.getfullargspec`.
+
+ This wrapper emulates `inspect.getfullargspec` in[^)]* Python2.
Args:
obj: A callable, possibly decorated.
@@ -152,34 +210,10 @@ def getfullargspec(obj): # pylint: disable=redefined-builtin
callable is not decorated, `inspect.getfullargspec()` will be called
directly on the callable.
"""
- if hasattr(_inspect, 'getfullargspec'):
- spec_fn = _inspect.getfullargspec
- else:
- def spec_fn(target):
- """Spec function that adding default value from FullArgSpec.
-
- It is used when getfullargspec is not available (eg in PY2).
-
- Args:
- target: the target object to inspect.
- Returns:
- The full argument specs with empty kwonlyargs, kwonlydefaults and
- annotations.
- """
- argspecs = _inspect.getargspec(target)
- fullargspecs = FullArgSpec(
- args=argspecs.args,
- varargs=argspecs.varargs,
- varkw=argspecs.keywords,
- defaults=argspecs.defaults,
- kwonlyargs=[],
- kwonlydefaults=None,
- annotations={})
- return fullargspecs
-
decorators, target = tf_decorator.unwrap(obj)
- return next((d.decorator_argspec for d in decorators
- if d.decorator_argspec is not None), spec_fn(target))
+ return next((d.decorator_argspec
+ for d in decorators
+ if d.decorator_argspec is not None), _getfullargspec(target))
def getcallargs(func, *positional, **named):
diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py
index 325131c4f4..beaf350de1 100644
--- a/tensorflow/python/util/tf_inspect_test.py
+++ b/tensorflow/python/util/tf_inspect_test.py
@@ -245,6 +245,52 @@ class TfInspectTest(test.TestCase):
self.assertEqual(partial_argspec,
tf_inspect.getargspec(partial_with_decorator))
+ def testGetArgSpecOnCallableObject(self):
+
+ class Callable(object):
+
+ def __call__(self, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.ArgSpec(
+ args=['self', 'a', 'b', 'c'],
+ varargs=None,
+ keywords=None,
+ defaults=(1, 'hello'))
+
+ test_obj = Callable()
+ self.assertEqual(argspec, tf_inspect.getargspec(test_obj))
+
+ def testGetArgSpecOnInitClass(self):
+
+ class InitClass(object):
+
+ def __init__(self, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.ArgSpec(
+ args=['self', 'a', 'b', 'c'],
+ varargs=None,
+ keywords=None,
+ defaults=(1, 'hello'))
+
+ self.assertEqual(argspec, tf_inspect.getargspec(InitClass))
+
+ def testGetArgSpecOnNewClass(self):
+
+ class NewClass(object):
+
+ def __new__(cls, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.ArgSpec(
+ args=['cls', 'a', 'b', 'c'],
+ varargs=None,
+ keywords=None,
+ defaults=(1, 'hello'))
+
+ self.assertEqual(argspec, tf_inspect.getargspec(NewClass))
+
def testGetDoc(self):
self.assertEqual('Test Decorated Function With Defaults Docstring.',
tf_inspect.getdoc(test_decorated_function_with_defaults))
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 0f465eda4f..8e839b523e 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -172,17 +172,20 @@ int IsSequenceHelper(PyObject* o) {
// Try not to return to Python - see if the type has already been seen
// before.
- // NOTE: It's not clear whether the lock is required (we should be holding the
- // python GIL in this code already).
- mutex_lock l(g_type_to_sequence_map);
auto* type_to_sequence_map = IsTypeSequenceMap();
auto* type = Py_TYPE(o);
- auto it = type_to_sequence_map->find(type);
- if (it != type_to_sequence_map->end()) {
- return it->second;
+ {
+ mutex_lock l(g_type_to_sequence_map);
+ auto it = type_to_sequence_map->find(type);
+ if (it != type_to_sequence_map->end()) {
+ return it->second;
+ }
}
+ // NOTE: We explicitly release the g_type_to_sequence_map mutex,
+ // because PyObject_IsInstance() may release the GIL, allowing another thread
+ // concurrent entry to this function.
int is_instance = PyObject_IsInstance(o, CollectionsSequenceType);
// Don't cache a failed is_instance check.
@@ -195,7 +198,10 @@ int IsSequenceHelper(PyObject* o) {
// leak, as there should only be a relatively small number of types in the
// map, and an even smaller number that are eligible for decref.
Py_INCREF(type);
- type_to_sequence_map->insert({type, is_sequence});
+ {
+ mutex_lock l(g_type_to_sequence_map);
+ type_to_sequence_map->insert({type, is_sequence});
+ }
return is_sequence;
}
diff --git a/tensorflow/security/advisory/tfsa-2018-001.md b/tensorflow/security/advisory/tfsa-2018-001.md
new file mode 100644
index 0000000000..e62757fb5f
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2018-001.md
@@ -0,0 +1,34 @@
+## TFSA-2018-001: BMP File Parser Out-of-bounds Read.
+
+### CVE Number
+
+CVE-2018-7574
+
+### Issue Description
+
+The BMP (bitmap image file graphics format) decoder had an out-of-bounds read
+due to insufficient checking of header sizes and signed integer values.
+
+### Impact
+
+The most likely consequence of this vulnerability would be that an invalid BMP
+file could lead to an unhandled process crash, but may permit read access to
+unintended regions of the TensorFlow process memory.
+
+### Vulnerable Versions
+
+TensorFlow 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0
+
+### Mitigation
+
+We have patched the vulnerability in GitHub commits
+[https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55](49f73c55).
+If users are running TensorFlow in production or on untrusted data, they are
+encouraged to apply this patch.
+
+Additionally, this patch has already been integrated into TensorFlow 1.7.0 and
+newer.
+
+### Credits
+
+This issue was discovered by the Blade Team of Tencent.
diff --git a/tensorflow/security/advisory/tfsa-2018-002.md b/tensorflow/security/advisory/tfsa-2018-002.md
new file mode 100644
index 0000000000..baf3fb418e
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2018-002.md
@@ -0,0 +1,33 @@
+## TFSA-2018-002: GIF File Parsing Null Pointer Dereference Error
+
+### CVE Number
+
+CVE-2018-7576
+
+### Issue Description
+
+When parsing certain invalid GIF files, an internal function in the GIF decoder
+returned a null pointer, which was subsequently used as an argument to strcat.
+
+### Impact
+
+A maliciously crafted GIF could be used to cause the TensorFlow process to
+crash.
+
+### Vulnerable Versions
+
+TensorFlow 1.0.0, 1.0.1, 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1 1.4.1, 1.5.0, 1.5.1
+
+### Mitigation
+
+We have patched the vulnerability in GitHub commit
+[https://github.com/tensorflow/tensorflow/commit/c48431588e7cf8aff61d4c299231e3e925144df8](c4843158).
+If users are running TensorFlow in production or on untrusted data, they are
+encouraged to apply this patch.
+
+Additionally, this patch has already been integrated into TensorFlow 1.6.0 and
+newer.
+
+### Credits
+
+This issue was discovered by the Blade Team of Tencent.
diff --git a/tensorflow/security/advisory/tfsa-2018-003.md b/tensorflow/security/advisory/tfsa-2018-003.md
new file mode 100644
index 0000000000..e20e358f29
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2018-003.md
@@ -0,0 +1,48 @@
+## TFSA-2018-003: TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability
+
+### CVE Number
+
+CVE-2018-8825
+
+### Issue Description
+
+The TensorFlow Lite TOCO compiler does not perform correct boundary checks when
+reading from some fields within TFLite files.
+
+As background, TFLite files are based on the FlatBuffers serialization format,
+which does not have bounds checking built-in, rather it relies on the clients to
+handle the appropriate security checks by themselves.
+
+In particular, TOCO is not performing correct bounds checks in the following places:
+* Out of bounds read in TOCO in import.cc:42
+* Null dereference in TOCO in import.cc:135
+* Out of bounds read in TOCO in import.cc:104
+* Null dereference in TOCO in import.cc:121
+* Out of bounds read in TOCO in import.cc:62
+* Out of bounds read in TOCO in operator.cc:48
+* Out of bounds read in TOCO graph_transformations (propagate_fixed_sizes.cc:93)
+
+
+### Impact
+
+Users passing a malformed or malicious version of a TFLite graph into TOCO will
+cause TOCO to crash or cause a buffer overflow, potentially allowing malicious
+code to be executed.
+
+### Vulnerable Versions
+
+TensorFlow 1.5.0, 1.5.1, 1.6.0, 1.7.0
+
+### Mitigation
+
+We have patched the vulnerability in GitHub commits [https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476](41335abb) and
+[https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476](41335abb) and
+If users are running the TensorFlow TFLite TOCO compiler in production or on
+untrusted data, they are encouraged to apply this patch.
+
+Additionally, we have released TensorFlow version 1.7.1 to mitigate this
+vulnerability.
+
+### Credits
+
+This issue was discovered by the Blade Team of Tencent.
diff --git a/tensorflow/security/advisory/tfsa-2018-004.md b/tensorflow/security/advisory/tfsa-2018-004.md
new file mode 100644
index 0000000000..d172247288
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2018-004.md
@@ -0,0 +1,35 @@
+## TFSA-2018-004: Checkpoint Meta File Out-of-Bounds Read
+
+### CVE Number
+
+CVE-2018-7575
+
+### Issue Description
+
+The block size in meta file might contain a large int64 value which causes
+an integer overflow upon addition. Subsequent code using n as index may cause
+an out-of-bounds read.
+
+### Impact
+
+A maliciously crafted meta checkpoint could be used to cause the TensorFlow
+process to perform an out of bounds read on in process memory.
+
+### Vulnerable Versions
+
+TensorFlow 1.0.0, 1.0.1, 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.7.0
+
+### Mitigation
+
+We have patched the vulnerability in GitHub commit
+[https://github.com/tensorflow/tensorflow/commit/d107fee1e4a9a4462f01564798d345802acc2aef](d107fee1).
+If users are running TensorFlow on untrusted meta checkpoints, such as those
+downloaded from the Internet, in production or on untrusted data, they are
+encouraged to apply this patch.
+
+Additionally, we have released TensorFlow version 1.7.1 to mitigate this
+vulnerability.
+
+### Credits
+
+This issue was discovered by the Blade Team of Tencent.
diff --git a/tensorflow/security/advisory/tfsa-2018-005.md b/tensorflow/security/advisory/tfsa-2018-005.md
new file mode 100644
index 0000000000..1c91567db5
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2018-005.md
@@ -0,0 +1,36 @@
+## TFSA-2018-005: Old Snappy Library Usage Resulting in Memcpy Parameter Overlap
+
+### CVE Number
+
+CVE-2018-7577
+
+### Issue Description
+
+TensorFlow checkpoint meta file uses Google's [https://github.com/google/snappy](snappy)
+compression/decompression library. There is a memcpy-param-overlap issue in the
+version of snappy currently used by TensorFlow.
+
+### Impact
+
+A maliciously crafted checkpoint meta file could cause TensorFlow to crash or
+read from other parts of its process memory.
+
+### Vulnerable Versions
+
+TensorFlow 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.7.0
+
+### Mitigation
+
+We have patched the vulnerability in GitHub commit
+[https://github.com/tensorflow/tensorflow/commit/dfa9921e6343727b05f42f8d4a918b19528ff994](dfa9921e)
+by upgrading the version of the snappy library used by TensorFlow to v1.1.7.
+
+If users are loading untrusted checkpoints in TensorFlow, we encourage users to
+apply the patch to upgrade snappy.
+
+Additionally, we have released TensorFlow version 1.7.1 to mitigate this
+vulnerability.
+
+### Credits
+
+This issue was discovered by the Blade Team of Tencent.
diff --git a/tensorflow/security/advisory/tfsa-2018-006.md b/tensorflow/security/advisory/tfsa-2018-006.md
new file mode 100644
index 0000000000..a1d1a9f3d1
--- /dev/null
+++ b/tensorflow/security/advisory/tfsa-2018-006.md
@@ -0,0 +1,35 @@
+## TFSA-2018-006: Crafted Configuration File results in Invalid Memory Access
+
+### CVE Number
+
+CVE-2018-10055
+
+### Issue Description
+
+A maliciously crafted configuration file passed into the TensorFlow XLA compiler
+could cause an invalid memory access and/or a heap buffer overflow.
+
+### Impact
+
+A maliciously crafted configuration file could cause TensorFlow to crash or
+read from other parts of its process memory.
+
+### Vulnerable Versions
+
+TensorFlow 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.7.0
+
+### Mitigation
+
+We have patched the vulnerability in GitHub commit
+[https://github.com/tensorflow/tensorflow/commit/c89ab82a82585cdaa90bf4911980e9e845909e78](c89ab82a).
+
+If users are loading untrusted configurations in TensorFlow, we encourage users
+to apply the patch to upgrade snappy or upgrade the version of TensorFlow they
+are currently using.
+
+Additionally, we have released TensorFlow version 1.7.1 to mitigate this
+vulnerability.
+
+### Credits
+
+This issue was discovered by the Blade Team of Tencent.
diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md
new file mode 100644
index 0000000000..c1f9f1da74
--- /dev/null
+++ b/tensorflow/security/index.md
@@ -0,0 +1,18 @@
+# TensorFlow Security Advisories
+
+We regularly publish security advisories about using TensorFlow.
+
+*Note*: In conjunction with these security advisories, we strongly encourage
+TensorFlow users to read and understand TensorFlow's security model as outlined
+in [https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](SECURITY.md).
+
+| Advisory Number | Type | Versions affected | Reported by | Additional Information |
+|-----------------|--------------------|:-----------------:|-----------------------|-----------------------------|
+| TFSA-2018-006 | Crafted Configuration File results in Invalid Memory Access | <= 1.7 | Blade Team of Tencent | |
+| TFSA-2018-005 | Old Snappy Library Usage Resulting in Memcpy Parameter Overlap | <= 1.7 | Blade Team of Tencent | |
+| TFSA-2018-004 | Checkpoint Meta File Out-of-Bounds Read | <= 1.7 | Blade Team of Tencent | |
+| TFSA-2018-003 | TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | |
+| TFSA-2018-002 | GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | |
+| TFSA-2018-001 | BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | |
+| - | Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
+
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 5ece80e551..c2c0c283b3 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -457,6 +457,9 @@ class ScopedFilterDescriptor {
case dnn::FilterLayout::kOutputInputYX:
format = CUDNN_TENSOR_NCHW;
break;
+ case dnn::FilterLayout::kOutputYXInput:
+ format = CUDNN_TENSOR_NHWC;
+ break;
case dnn::FilterLayout::kOutputInputYX4:
format = CUDNN_TENSOR_NCHW_VECT_C;
break;
@@ -3046,53 +3049,6 @@ bool CudnnSupport::DoFusedConvolve(
output_profile_result);
}
-namespace {
-// NOTE(keveman): Temporary data layout transformation until cuDNN supports
-// kBatchYXDepth for backward pass. This function allocates temporary memory,
-// lays out the source data into the temporary but in the kBatchDepthXY
-// layout, and returns the temporary memory. The caller is responsible for
-// deallocating the temporary. Since the allocation is done using Stream's
-// AllocateTemporaryMemory, a later BlockHostUntilDone could be used for
-// deallocation.
-//
-// transform_scratch is populated with a legitimate temporary allocation iff
-// the original output data needs to be transformed.
-template <class T>
-DeviceMemory<T> MaybeTransformLayout(
- Stream* stream, const CudnnHandle& cudnn,
- dnn::BatchDescriptor* output_descriptor,
- DeviceMemory<T> backward_output_data,
- std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch) {
- if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) {
- return backward_output_data;
- }
- CHECK(output_descriptor->layout() == dnn::DataLayout::kBatchYXDepth);
- *transform_scratch =
- stream->AllocateTemporaryArray<T>(backward_output_data.ElementCount())
- .ConsumeValueOrDie();
- dnn::BatchDescriptor transformed_output_descriptor;
- transformed_output_descriptor.CloneFrom(*output_descriptor);
- transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX);
- cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
- ScopedTensorDescriptor orig_out_back_nd(*output_descriptor, cudnn_type);
- ScopedTensorDescriptor transformed_out_back_nd(transformed_output_descriptor,
- cudnn_type);
-
- float alpha = 1.0f;
- float beta = 0.0f;
- auto status = cudnnTransformTensor(
- cudnn.handle(), &alpha, orig_out_back_nd.handle(),
- backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(),
- (*transform_scratch)->mutable_device_memory()->opaque());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "Failed to transform the data layout.";
- }
- output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX);
- return (*transform_scratch)->device_memory();
-}
-} // namespace
-
bool CudnnSupport::DoTransformTensor(Stream* stream,
const dnn::BatchDescriptor& input_desc,
dnn::DataType input_type,
@@ -3124,7 +3080,7 @@ template <class T>
bool CudnnSupport::DoConvolveBackwardDataImpl(
Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
- const dnn::BatchDescriptor& output_descriptor_in,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<T> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& input_descriptor,
@@ -3145,14 +3101,6 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
auto cudnn = cudnn_->GetHandle(parent_, stream);
- // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
- dnn::BatchDescriptor output_descriptor;
- output_descriptor.CloneFrom(output_descriptor_in);
- std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
- backward_output_data =
- MaybeTransformLayout(stream, cudnn, &output_descriptor,
- backward_output_data, &transform_scratch);
-
ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
ScopedTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
@@ -3386,7 +3334,7 @@ template <class T>
bool CudnnSupport::DoConvolveBackwardFilterImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
- const dnn::BatchDescriptor& output_descriptor_in,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<T> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
@@ -3407,14 +3355,6 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
auto cudnn = cudnn_->GetHandle(parent_, stream);
- // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
- dnn::BatchDescriptor output_descriptor;
- output_descriptor.CloneFrom(output_descriptor_in);
- std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch;
- backward_output_data =
- MaybeTransformLayout(stream, cudnn, &output_descriptor,
- backward_output_data, &transform_scratch);
-
ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 38abc66079..3df5365c23 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -349,6 +349,8 @@ enum class FilterLayout : int64 {
kOutputInputYX = 0, // cuDNN's default filter layout, laid out as:
// (major) output feature maps >> input feature maps >>
// rows >> columns (minor).
+ kOutputYXInput, // major to minor:
+ // (output features, row, columns, input features)
kOutputInputYX4, // laid out the same as kOutputInputYX but each element is a
// vector of 4 feature maps.
kInputYXOutput, // Same as dist_belief's default filter layout.
diff --git a/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt
index 0fb1aaba28..f1dffd5952 100644
--- a/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt
@@ -1,108 +1,70 @@
path: "tensorflow.AttrValue.ListValue"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.attr_value_pb2.ListValue\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "B_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FUNC_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "F_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "I_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SHAPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "S_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TENSOR_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TYPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "ListValue"
+ field {
+ name: "s"
+ number: 2
+ label: LABEL_REPEATED
+ type: TYPE_BYTES
+ }
+ field {
+ name: "i"
+ number: 3
+ label: LABEL_REPEATED
+ type: TYPE_INT64
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "f"
+ number: 4
+ label: LABEL_REPEATED
+ type: TYPE_FLOAT
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "b"
+ number: 5
+ label: LABEL_REPEATED
+ type: TYPE_BOOL
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "type"
+ number: 6
+ label: LABEL_REPEATED
+ type: TYPE_ENUM
+ type_name: ".tensorflow.DataType"
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "shape"
+ number: 7
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorShapeProto"
+ }
+ field {
+ name: "tensor"
+ number: 8
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorProto"
+ }
+ field {
+ name: "func"
+ number: 9
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.NameAttrList"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt
index e7a3a1f02f..6ccd64f428 100644
--- a/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt
@@ -1,120 +1,151 @@
path: "tensorflow.AttrValue"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.attr_value_pb2.AttrValue\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "B_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FUNC_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "F_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "I_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "LIST_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "ListValue"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "PLACEHOLDER_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SHAPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "S_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TENSOR_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TYPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "AttrValue"
+ field {
+ name: "s"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ field {
+ name: "i"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ oneof_index: 0
+ }
+ field {
+ name: "f"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ oneof_index: 0
+ }
+ field {
+ name: "b"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ oneof_index: 0
+ }
+ field {
+ name: "type"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.DataType"
+ oneof_index: 0
+ }
+ field {
+ name: "shape"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorShapeProto"
+ oneof_index: 0
+ }
+ field {
+ name: "tensor"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorProto"
+ oneof_index: 0
+ }
+ field {
+ name: "list"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.AttrValue.ListValue"
+ oneof_index: 0
+ }
+ field {
+ name: "func"
+ number: 10
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.NameAttrList"
+ oneof_index: 0
+ }
+ field {
+ name: "placeholder"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ oneof_index: 0
+ }
+ nested_type {
+ name: "ListValue"
+ field {
+ name: "s"
+ number: 2
+ label: LABEL_REPEATED
+ type: TYPE_BYTES
+ }
+ field {
+ name: "i"
+ number: 3
+ label: LABEL_REPEATED
+ type: TYPE_INT64
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "f"
+ number: 4
+ label: LABEL_REPEATED
+ type: TYPE_FLOAT
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "b"
+ number: 5
+ label: LABEL_REPEATED
+ type: TYPE_BOOL
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "type"
+ number: 6
+ label: LABEL_REPEATED
+ type: TYPE_ENUM
+ type_name: ".tensorflow.DataType"
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "shape"
+ number: 7
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorShapeProto"
+ }
+ field {
+ name: "tensor"
+ number: 8
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorProto"
+ }
+ field {
+ name: "func"
+ number: 9
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.NameAttrList"
+ }
+ }
+ oneof_decl {
+ name: "value"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt
index 29bb3be35c..d9b1426828 100644
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt
@@ -1,84 +1,21 @@
path: "tensorflow.ConfigProto.DeviceCountEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.config_pb2.DeviceCountEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "DeviceCountEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
new file mode 100644
index 0000000000..9e09a8d48e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.ConfigProto.Experimental"
+tf_proto {
+ descriptor {
+ name: "Experimental"
+ field {
+ name: "collective_group_leader"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
index 009d64aed0..4af4ed70ef 100644
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
@@ -1,140 +1,136 @@
path: "tensorflow.ConfigProto"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.config_pb2.ConfigProto\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "ALLOW_SOFT_PLACEMENT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CLUSTER_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "DEVICE_COUNT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DEVICE_FILTERS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DeviceCountEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "GPU_OPTIONS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GRAPH_OPTIONS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "INTER_OP_PARALLELISM_THREADS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "INTRA_OP_PARALLELISM_THREADS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "ISOLATE_SESSION_STATE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "LOG_DEVICE_PLACEMENT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OPERATION_TIMEOUT_IN_MS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PLACEMENT_PERIOD_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "RPC_OPTIONS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SESSION_INTER_OP_THREAD_POOL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "USE_PER_SESSION_THREADS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "ConfigProto"
+ field {
+ name: "device_count"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.ConfigProto.DeviceCountEntry"
+ }
+ field {
+ name: "intra_op_parallelism_threads"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "inter_op_parallelism_threads"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "use_per_session_threads"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "session_inter_op_thread_pool"
+ number: 12
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.ThreadPoolOptionProto"
+ }
+ field {
+ name: "placement_period"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "device_filters"
+ number: 4
+ label: LABEL_REPEATED
+ type: TYPE_STRING
+ }
+ field {
+ name: "gpu_options"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.GPUOptions"
+ }
+ field {
+ name: "allow_soft_placement"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "log_device_placement"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "graph_options"
+ number: 10
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.GraphOptions"
+ }
+ field {
+ name: "operation_timeout_in_ms"
+ number: 11
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "rpc_options"
+ number: 13
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.RPCOptions"
+ }
+ field {
+ name: "cluster_def"
+ number: 14
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.ClusterDef"
+ }
+ field {
+ name: "isolate_session_state"
+ number: 15
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "experimental"
+ number: 16
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.ConfigProto.Experimental"
+ }
+ nested_type {
+ name: "DeviceCountEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ options {
+ map_entry: true
+ }
+ }
+ nested_type {
+ name: "Experimental"
+ field {
+ name: "collective_group_leader"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-event.pbtxt b/tensorflow/tools/api/golden/tensorflow.-event.pbtxt
index 9bf8c12428..3b75a1735b 100644
--- a/tensorflow/tools/api/golden/tensorflow.-event.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-event.pbtxt
@@ -1,112 +1,74 @@
path: "tensorflow.Event"
-tf_class {
- is_instance: "<class \'tensorflow.core.util.event_pb2.Event\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FILE_VERSION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GRAPH_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "LOG_MESSAGE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "META_GRAPH_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SESSION_LOG_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STEP_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SUMMARY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TAGGED_RUN_METADATA_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "WALL_TIME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Event"
+ field {
+ name: "wall_time"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_DOUBLE
+ }
+ field {
+ name: "step"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "file_version"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ oneof_index: 0
+ }
+ field {
+ name: "graph_def"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ field {
+ name: "summary"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary"
+ oneof_index: 0
+ }
+ field {
+ name: "log_message"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.LogMessage"
+ oneof_index: 0
+ }
+ field {
+ name: "session_log"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SessionLog"
+ oneof_index: 0
+ }
+ field {
+ name: "tagged_run_metadata"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TaggedRunMetadata"
+ oneof_index: 0
+ }
+ field {
+ name: "meta_graph_def"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ oneof_decl {
+ name: "what"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt
index 875d802a9c..f819b174c0 100644
--- a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt
@@ -1,116 +1,86 @@
path: "tensorflow.GPUOptions"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.config_pb2.GPUOptions\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "ALLOCATOR_TYPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "ALLOW_GROWTH_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DEFERRED_DELETION_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "EXPERIMENTAL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Experimental"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FORCE_GPU_COMPATIBLE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PER_PROCESS_GPU_MEMORY_FRACTION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "POLLING_ACTIVE_DELAY_USECS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "POLLING_INACTIVE_DELAY_MSECS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VISIBLE_DEVICE_LIST_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "GPUOptions"
+ field {
+ name: "per_process_gpu_memory_fraction"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_DOUBLE
+ }
+ field {
+ name: "allow_growth"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "allocator_type"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "deferred_deletion_bytes"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "visible_device_list"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "polling_active_delay_usecs"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "polling_inactive_delay_msecs"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "force_gpu_compatible"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "experimental"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.GPUOptions.Experimental"
+ }
+ nested_type {
+ name: "Experimental"
+ field {
+ name: "virtual_devices"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.GPUOptions.Experimental.VirtualDevices"
+ }
+ field {
+ name: "use_unified_memory"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ nested_type {
+ name: "VirtualDevices"
+ field {
+ name: "memory_limit_mb"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_FLOAT
+ }
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt
index 1495e847cb..19eccff03d 100644
--- a/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt
@@ -1,92 +1,36 @@
path: "tensorflow.GraphDef"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.graph_pb2.GraphDef\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "LIBRARY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NODE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VERSIONS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VERSION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "GraphDef"
+ field {
+ name: "node"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.NodeDef"
+ }
+ field {
+ name: "versions"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.VersionDef"
+ }
+ field {
+ name: "version"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ options {
+ deprecated: true
+ }
+ }
+ field {
+ name: "library"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.FunctionDefLibrary"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt
index 0844f891ca..a9f99bc171 100644
--- a/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt
@@ -1,112 +1,67 @@
path: "tensorflow.GraphOptions"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.config_pb2.GraphOptions\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "BUILD_COST_MODEL_AFTER_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "BUILD_COST_MODEL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "ENABLE_BFLOAT16_SENDRECV_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "ENABLE_RECV_SCHEDULING_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "INFER_SHAPES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OPTIMIZER_OPTIONS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PLACE_PRUNED_GRAPH_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "REWRITE_OPTIONS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TIMELINE_STEP_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "GraphOptions"
+ field {
+ name: "enable_recv_scheduling"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "optimizer_options"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.OptimizerOptions"
+ }
+ field {
+ name: "build_cost_model"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "build_cost_model_after"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "infer_shapes"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "place_pruned_graph"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "enable_bfloat16_sendrecv"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "timeline_step"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "rewrite_options"
+ number: 10
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.RewriterConfig"
+ }
+ reserved_range {
+ start: 1
+ end: 2
+ }
+ reserved_name: "skip_common_subexpression_elimination"
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt
index 2567d2fe60..d4402f330b 100644
--- a/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt
@@ -1,104 +1,54 @@
path: "tensorflow.HistogramProto"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.HistogramProto\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "BUCKET_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "BUCKET_LIMIT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "MAX_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "MIN_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NUM_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SUM_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SUM_SQUARES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "HistogramProto"
+ field {
+ name: "min"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_DOUBLE
+ }
+ field {
+ name: "max"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_DOUBLE
+ }
+ field {
+ name: "num"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_DOUBLE
+ }
+ field {
+ name: "sum"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_DOUBLE
+ }
+ field {
+ name: "sum_squares"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_DOUBLE
+ }
+ field {
+ name: "bucket_limit"
+ number: 6
+ label: LABEL_REPEATED
+ type: TYPE_DOUBLE
+ options {
+ packed: true
+ }
+ }
+ field {
+ name: "bucket"
+ number: 7
+ label: LABEL_REPEATED
+ type: TYPE_DOUBLE
+ options {
+ packed: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt b/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt
index a43c5eb7e3..5023aa96bf 100644
--- a/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt
@@ -1,112 +1,46 @@
path: "tensorflow.LogMessage"
-tf_class {
- is_instance: "<class \'tensorflow.core.util.event_pb2.LogMessage\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DEBUGGING"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "ERROR"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FATAL"
- mtype: "<type \'int\'>"
- }
- member {
- name: "INFO"
- mtype: "<type \'int\'>"
- }
- member {
- name: "LEVEL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Level"
- mtype: "<class \'google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper\'>"
- }
- member {
- name: "MESSAGE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "UNKNOWN"
- mtype: "<type \'int\'>"
- }
- member {
- name: "WARN"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "LogMessage"
+ field {
+ name: "level"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.LogMessage.Level"
+ }
+ field {
+ name: "message"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ enum_type {
+ name: "Level"
+ value {
+ name: "UNKNOWN"
+ number: 0
+ }
+ value {
+ name: "DEBUGGING"
+ number: 10
+ }
+ value {
+ name: "INFO"
+ number: 20
+ }
+ value {
+ name: "WARN"
+ number: 30
+ }
+ value {
+ name: "ERROR"
+ number: 40
+ }
+ value {
+ name: "FATAL"
+ number: 50
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt
index 3572126fbf..0ba09bec4b 100644
--- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt
@@ -1,84 +1,22 @@
path: "tensorflow.MetaGraphDef.CollectionDefEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.CollectionDefEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "CollectionDefEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.CollectionDef"
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt
index b0e9831154..41c62a407b 100644
--- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt
@@ -1,104 +1,50 @@
path: "tensorflow.MetaGraphDef.MetaInfoDef"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.MetaInfoDef\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "ANY_INFO_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "META_GRAPH_VERSION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STRIPPED_DEFAULT_ATTRS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STRIPPED_OP_LIST_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TAGS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TENSORFLOW_GIT_VERSION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TENSORFLOW_VERSION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "MetaInfoDef"
+ field {
+ name: "meta_graph_version"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "stripped_op_list"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.OpList"
+ }
+ field {
+ name: "any_info"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".google.protobuf.Any"
+ }
+ field {
+ name: "tags"
+ number: 4
+ label: LABEL_REPEATED
+ type: TYPE_STRING
+ }
+ field {
+ name: "tensorflow_version"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "tensorflow_git_version"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "stripped_default_attrs"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt
index 48fccac99d..73dc414a77 100644
--- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt
@@ -1,84 +1,22 @@
path: "tensorflow.MetaGraphDef.SignatureDefEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.SignatureDefEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "SignatureDefEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SignatureDef"
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt
index 3e683a8715..d71c2358c9 100644
--- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt
@@ -1,112 +1,133 @@
path: "tensorflow.MetaGraphDef"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.MetaGraphDef\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "ASSET_FILE_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "COLLECTION_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CollectionDefEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "GRAPH_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "META_INFO_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "MetaInfoDef"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "SAVER_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SIGNATURE_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SignatureDefEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "MetaGraphDef"
+ field {
+ name: "meta_info_def"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.MetaGraphDef.MetaInfoDef"
+ }
+ field {
+ name: "graph_def"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.GraphDef"
+ }
+ field {
+ name: "saver_def"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SaverDef"
+ }
+ field {
+ name: "collection_def"
+ number: 4
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.MetaGraphDef.CollectionDefEntry"
+ }
+ field {
+ name: "signature_def"
+ number: 5
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.MetaGraphDef.SignatureDefEntry"
+ }
+ field {
+ name: "asset_file_def"
+ number: 6
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.AssetFileDef"
+ }
+ nested_type {
+ name: "MetaInfoDef"
+ field {
+ name: "meta_graph_version"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "stripped_op_list"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.OpList"
+ }
+ field {
+ name: "any_info"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".google.protobuf.Any"
+ }
+ field {
+ name: "tags"
+ number: 4
+ label: LABEL_REPEATED
+ type: TYPE_STRING
+ }
+ field {
+ name: "tensorflow_version"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "tensorflow_git_version"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "stripped_default_attrs"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ }
+ nested_type {
+ name: "CollectionDefEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.CollectionDef"
+ }
+ options {
+ map_entry: true
+ }
+ }
+ nested_type {
+ name: "SignatureDefEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SignatureDef"
+ }
+ options {
+ map_entry: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt
index 2750bd780c..b119b20877 100644
--- a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt
@@ -1,84 +1,22 @@
path: "tensorflow.NameAttrList.AttrEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.attr_value_pb2.AttrEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "AttrEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.AttrValue"
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt
index d10faf67d0..fcdb411ffc 100644
--- a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt
@@ -1,88 +1,38 @@
path: "tensorflow.NameAttrList"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.attr_value_pb2.NameAttrList\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "ATTR_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "AttrEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "NameAttrList"
+ field {
+ name: "name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "attr"
+ number: 2
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.NameAttrList.AttrEntry"
+ }
+ nested_type {
+ name: "AttrEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.AttrValue"
+ }
+ options {
+ map_entry: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt
index b1b62d60f1..622e4c3d0f 100644
--- a/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt
@@ -1,84 +1,22 @@
path: "tensorflow.NodeDef.AttrEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.node_def_pb2.AttrEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "AttrEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.AttrValue"
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt
index b812b4df2b..646fa8abb9 100644
--- a/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt
@@ -1,100 +1,56 @@
path: "tensorflow.NodeDef"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.node_def_pb2.NodeDef\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "ATTR_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "AttrEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "DEVICE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "INPUT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OP_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "NodeDef"
+ field {
+ name: "name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "op"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "input"
+ number: 3
+ label: LABEL_REPEATED
+ type: TYPE_STRING
+ }
+ field {
+ name: "device"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "attr"
+ number: 5
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.NodeDef.AttrEntry"
+ }
+ nested_type {
+ name: "AttrEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.AttrValue"
+ }
+ options {
+ map_entry: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt
index 6cac5c4d99..3ccf9d459b 100644
--- a/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt
@@ -1,132 +1,74 @@
path: "tensorflow.OptimizerOptions"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.config_pb2.OptimizerOptions\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DEFAULT"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "DO_COMMON_SUBEXPRESSION_ELIMINATION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DO_CONSTANT_FOLDING_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DO_FUNCTION_INLINING_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "GLOBAL_JIT_LEVEL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GlobalJitLevel"
- mtype: "<class \'google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper\'>"
- }
- member {
- name: "L0"
- mtype: "<type \'int\'>"
- }
- member {
- name: "L1"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Level"
- mtype: "<class \'google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper\'>"
- }
- member {
- name: "MAX_FOLDED_CONSTANT_IN_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OFF"
- mtype: "<type \'int\'>"
- }
- member {
- name: "ON_1"
- mtype: "<type \'int\'>"
- }
- member {
- name: "ON_2"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OPT_LEVEL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "OptimizerOptions"
+ field {
+ name: "do_common_subexpression_elimination"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "do_constant_folding"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "max_folded_constant_in_bytes"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "do_function_inlining"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "opt_level"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.OptimizerOptions.Level"
+ }
+ field {
+ name: "global_jit_level"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.OptimizerOptions.GlobalJitLevel"
+ }
+ enum_type {
+ name: "Level"
+ value {
+ name: "L1"
+ number: 0
+ }
+ value {
+ name: "L0"
+ number: -1
+ }
+ }
+ enum_type {
+ name: "GlobalJitLevel"
+ value {
+ name: "DEFAULT"
+ number: 0
+ }
+ value {
+ name: "OFF"
+ number: -1
+ }
+ value {
+ name: "ON_1"
+ number: 1
+ }
+ value {
+ name: "ON_2"
+ number: 2
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt
index 808fa0fa21..1287940326 100644
--- a/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt
@@ -1,88 +1,27 @@
path: "tensorflow.RunMetadata"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.config_pb2.RunMetadata\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "COST_GRAPH_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "PARTITION_GRAPHS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STEP_STATS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "RunMetadata"
+ field {
+ name: "step_stats"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.StepStats"
+ }
+ field {
+ name: "cost_graph"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.CostGraphDef"
+ }
+ field {
+ name: "partition_graphs"
+ number: 3
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.GraphDef"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt
new file mode 100644
index 0000000000..537e73aa89
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.RunOptions.Experimental"
+tf_proto {
+ descriptor {
+ name: "Experimental"
+ field {
+ name: "collective_graph_key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt
index 2f3e7f1a84..cec04a2bf0 100644
--- a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt
@@ -1,120 +1,83 @@
path: "tensorflow.RunOptions"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.config_pb2.RunOptions\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DEBUG_OPTIONS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FULL_TRACE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "HARDWARE_TRACE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "INTER_OP_THREAD_POOL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NO_TRACE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OUTPUT_PARTITION_GRAPHS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "REPORT_TENSOR_ALLOCATIONS_UPON_OOM_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SOFTWARE_TRACE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TIMEOUT_IN_MS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TRACE_LEVEL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TraceLevel"
- mtype: "<class \'google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "RunOptions"
+ field {
+ name: "trace_level"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.RunOptions.TraceLevel"
+ }
+ field {
+ name: "timeout_in_ms"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "inter_op_thread_pool"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "output_partition_graphs"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "debug_options"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.DebugOptions"
+ }
+ field {
+ name: "report_tensor_allocations_upon_oom"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "experimental"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.RunOptions.Experimental"
+ }
+ nested_type {
+ name: "Experimental"
+ field {
+ name: "collective_graph_key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ }
+ enum_type {
+ name: "TraceLevel"
+ value {
+ name: "NO_TRACE"
+ number: 0
+ }
+ value {
+ name: "SOFTWARE_TRACE"
+ number: 1
+ }
+ value {
+ name: "HARDWARE_TRACE"
+ number: 2
+ }
+ value {
+ name: "FULL_TRACE"
+ number: 3
+ }
+ }
+ reserved_range {
+ start: 4
+ end: 5
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt b/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt
index ec66d7f335..259f241874 100644
--- a/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt
@@ -1,108 +1,44 @@
path: "tensorflow.SessionLog"
-tf_class {
- is_instance: "<class \'tensorflow.core.util.event_pb2.SessionLog\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CHECKPOINT"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CHECKPOINT_PATH_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "MSG_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "START"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STATUS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STATUS_UNSPECIFIED"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STOP"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SessionStatus"
- mtype: "<class \'google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "SessionLog"
+ field {
+ name: "status"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.SessionLog.SessionStatus"
+ }
+ field {
+ name: "checkpoint_path"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "msg"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ enum_type {
+ name: "SessionStatus"
+ value {
+ name: "STATUS_UNSPECIFIED"
+ number: 0
+ }
+ value {
+ name: "START"
+ number: 1
+ }
+ value {
+ name: "STOP"
+ number: 2
+ }
+ value {
+ name: "CHECKPOINT"
+ number: 3
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt
index 067f02ce8c..a66b74b315 100644
--- a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt
@@ -1,84 +1,18 @@
path: "tensorflow.SummaryMetadata.PluginData"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.PluginData\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CONTENT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "PLUGIN_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "PluginData"
+ field {
+ name: "plugin_name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "content"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt
index b9156521cc..c02575b962 100644
--- a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt
@@ -1,92 +1,40 @@
path: "tensorflow.SummaryMetadata"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.SummaryMetadata\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "DISPLAY_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "PLUGIN_DATA_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PluginData"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "SUMMARY_DESCRIPTION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "SummaryMetadata"
+ field {
+ name: "plugin_data"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SummaryMetadata.PluginData"
+ }
+ field {
+ name: "display_name"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "summary_description"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ nested_type {
+ name: "PluginData"
+ field {
+ name: "plugin_name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "content"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt
index 781010d75e..94f712073e 100644
--- a/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt
@@ -1,96 +1,36 @@
path: "tensorflow.Summary.Audio"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.Audio\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CONTENT_TYPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "ENCODED_AUDIO_STRING_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "LENGTH_FRAMES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NUM_CHANNELS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SAMPLE_RATE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Audio"
+ field {
+ name: "sample_rate"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ }
+ field {
+ name: "num_channels"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "length_frames"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "encoded_audio_string"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
+ field {
+ name: "content_type"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt
index feb9c7ee92..fc1acb483b 100644
--- a/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt
@@ -1,92 +1,30 @@
path: "tensorflow.Summary.Image"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.Image\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "COLORSPACE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "ENCODED_IMAGE_STRING_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "HEIGHT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "WIDTH_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Image"
+ field {
+ name: "height"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "width"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "colorspace"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "encoded_image_string"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt
index ffb4f45fc5..feb84b6ee9 100644
--- a/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt
@@ -1,112 +1,74 @@
path: "tensorflow.Summary.Value"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.Value\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "AUDIO_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "HISTO_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "IMAGE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "METADATA_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NODE_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OBSOLETE_OLD_STYLE_HISTOGRAM_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SIMPLE_VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TAG_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TENSOR_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Value"
+ field {
+ name: "node_name"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "tag"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "metadata"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SummaryMetadata"
+ }
+ field {
+ name: "simple_value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ oneof_index: 0
+ }
+ field {
+ name: "obsolete_old_style_histogram"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ field {
+ name: "image"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Image"
+ oneof_index: 0
+ }
+ field {
+ name: "histo"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.HistogramProto"
+ oneof_index: 0
+ }
+ field {
+ name: "audio"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Audio"
+ oneof_index: 0
+ }
+ field {
+ name: "tensor"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorProto"
+ oneof_index: 0
+ }
+ oneof_decl {
+ name: "value"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt
index 38de17fa9e..b2bdff7171 100644
--- a/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt
@@ -1,92 +1,144 @@
path: "tensorflow.Summary"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.Summary\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "Audio"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "Image"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Value"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Summary"
+ field {
+ name: "value"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Value"
+ }
+ nested_type {
+ name: "Image"
+ field {
+ name: "height"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "width"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "colorspace"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "encoded_image_string"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
+ }
+ nested_type {
+ name: "Audio"
+ field {
+ name: "sample_rate"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ }
+ field {
+ name: "num_channels"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "length_frames"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "encoded_audio_string"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
+ field {
+ name: "content_type"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ }
+ nested_type {
+ name: "Value"
+ field {
+ name: "node_name"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "tag"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "metadata"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SummaryMetadata"
+ }
+ field {
+ name: "simple_value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ oneof_index: 0
+ }
+ field {
+ name: "obsolete_old_style_histogram"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ field {
+ name: "image"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Image"
+ oneof_index: 0
+ }
+ field {
+ name: "histo"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.HistogramProto"
+ oneof_index: 0
+ }
+ field {
+ name: "audio"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Audio"
+ oneof_index: 0
+ }
+ field {
+ name: "tensor"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorProto"
+ oneof_index: 0
+ }
+ oneof_decl {
+ name: "value"
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt
index 425c35e067..0064c8460c 100644
--- a/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt
@@ -1,88 +1,24 @@
path: "tensorflow.TensorInfo.CooSparse"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.CooSparse\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DENSE_SHAPE_TENSOR_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "INDICES_TENSOR_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUES_TENSOR_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "CooSparse"
+ field {
+ name: "values_tensor_name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "indices_tensor_name"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "dense_shape_tensor_name"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt
index 41ea393be5..63566c808e 100644
--- a/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt
@@ -1,96 +1,59 @@
path: "tensorflow.TensorInfo"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.TensorInfo\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "COO_SPARSE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CooSparse"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "DTYPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TENSOR_SHAPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "TensorInfo"
+ field {
+ name: "name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ oneof_index: 0
+ }
+ field {
+ name: "coo_sparse"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorInfo.CooSparse"
+ oneof_index: 0
+ }
+ field {
+ name: "dtype"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.DataType"
+ }
+ field {
+ name: "tensor_shape"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorShapeProto"
+ }
+ nested_type {
+ name: "CooSparse"
+ field {
+ name: "values_tensor_name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "indices_tensor_name"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "dense_shape_tensor_name"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ }
+ oneof_decl {
+ name: "encoding"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index acc3fc4c5b..87543e374b 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -110,7 +110,7 @@ tf_module {
}
member_method {
name: "non_max_suppression"
- argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'None\'], "
+ argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
}
member_method {
name: "pad_to_bounding_box"
diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt
index bd5c36f390..e09c44cc9c 100644
--- a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt
@@ -1,80 +1,12 @@
path: "tensorflow.profiler.AdviceProto.Checker"
-tf_class {
- is_instance: "<class \'tensorflow.core.profiler.tfprof_output_pb2.Checker\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "REPORTS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Checker"
+ field {
+ name: "reports"
+ number: 2
+ label: LABEL_REPEATED
+ type: TYPE_STRING
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt
index 7c8c68e155..8746243549 100644
--- a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt
@@ -1,84 +1,22 @@
path: "tensorflow.profiler.AdviceProto.CheckersEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.profiler.tfprof_output_pb2.CheckersEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "CheckersEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.AdviceProto.Checker"
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt
index 1b789f4fc9..a8a8858ccd 100644
--- a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt
@@ -1,88 +1,41 @@
path: "tensorflow.profiler.AdviceProto"
-tf_class {
- is_instance: "<class \'tensorflow.core.profiler.tfprof_output_pb2.AdviceProto\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CHECKERS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Checker"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "CheckersEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "AdviceProto"
+ field {
+ name: "checkers"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.AdviceProto.CheckersEntry"
+ }
+ nested_type {
+ name: "CheckersEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.AdviceProto.Checker"
+ }
+ options {
+ map_entry: true
+ }
+ }
+ nested_type {
+ name: "Checker"
+ field {
+ name: "reports"
+ number: 2
+ label: LABEL_REPEATED
+ type: TYPE_STRING
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt
index f0b9605bee..afec73f537 100644
--- a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt
@@ -1,84 +1,22 @@
path: "tensorflow.profiler.GraphNodeProto.InputShapesEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.profiler.tfprof_output_pb2.InputShapesEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "InputShapesEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorShapeProto"
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt
index b80896a8a0..3c83177005 100644
--- a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt
@@ -1,188 +1,191 @@
path: "tensorflow.profiler.GraphNodeProto"
-tf_class {
- is_instance: "<class \'tensorflow.core.profiler.tfprof_output_pb2.GraphNodeProto\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "ACCELERATOR_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CHILDREN_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CPU_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "DEVICES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FLOAT_OPS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "INPUT_SHAPES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "InputShapesEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OUTPUT_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PARAMETERS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PEAK_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "REQUESTED_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "RESIDUAL_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "RUN_COUNT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SHAPES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TENSOR_VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_ACCELERATOR_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_CPU_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_DEFINITION_COUNT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_FLOAT_OPS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_OUTPUT_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_PARAMETERS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_PEAK_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_REQUESTED_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_RESIDUAL_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_RUN_COUNT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "GraphNodeProto"
+ field {
+ name: "name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "tensor_value"
+ number: 15
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.TFProfTensorProto"
+ }
+ field {
+ name: "run_count"
+ number: 21
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "exec_micros"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "accelerator_exec_micros"
+ number: 17
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "cpu_exec_micros"
+ number: 18
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "requested_bytes"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "peak_bytes"
+ number: 24
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "residual_bytes"
+ number: 25
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "output_bytes"
+ number: 26
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "parameters"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "float_ops"
+ number: 13
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "devices"
+ number: 10
+ label: LABEL_REPEATED
+ type: TYPE_STRING
+ }
+ field {
+ name: "total_definition_count"
+ number: 23
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_run_count"
+ number: 22
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_exec_micros"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_accelerator_exec_micros"
+ number: 19
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_cpu_exec_micros"
+ number: 20
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_requested_bytes"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_peak_bytes"
+ number: 27
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_residual_bytes"
+ number: 28
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_output_bytes"
+ number: 29
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_parameters"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_float_ops"
+ number: 14
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "shapes"
+ number: 11
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorShapeProto"
+ }
+ field {
+ name: "input_shapes"
+ number: 16
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.GraphNodeProto.InputShapesEntry"
+ }
+ field {
+ name: "children"
+ number: 12
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.GraphNodeProto"
+ }
+ nested_type {
+ name: "InputShapesEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorShapeProto"
+ }
+ options {
+ map_entry: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt
index 33deff6497..2b08a05437 100644
--- a/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt
@@ -1,160 +1,134 @@
path: "tensorflow.profiler.MultiGraphNodeProto"
-tf_class {
- is_instance: "<class \'tensorflow.core.profiler.tfprof_output_pb2.MultiGraphNodeProto\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "ACCELERATOR_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CHILDREN_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CPU_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FLOAT_OPS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GRAPH_NODES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OUTPUT_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PARAMETERS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PEAK_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "REQUESTED_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "RESIDUAL_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_ACCELERATOR_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_CPU_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_EXEC_MICROS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_FLOAT_OPS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_OUTPUT_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_PARAMETERS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_PEAK_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_REQUESTED_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TOTAL_RESIDUAL_BYTES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "MultiGraphNodeProto"
+ field {
+ name: "name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "exec_micros"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "accelerator_exec_micros"
+ number: 12
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "cpu_exec_micros"
+ number: 13
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "requested_bytes"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "peak_bytes"
+ number: 16
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "residual_bytes"
+ number: 17
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "output_bytes"
+ number: 18
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "parameters"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "float_ops"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_exec_micros"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_accelerator_exec_micros"
+ number: 14
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_cpu_exec_micros"
+ number: 15
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_requested_bytes"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_peak_bytes"
+ number: 19
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_residual_bytes"
+ number: 20
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_output_bytes"
+ number: 21
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_parameters"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "total_float_ops"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "graph_nodes"
+ number: 10
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.GraphNodeProto"
+ }
+ field {
+ name: "children"
+ number: 11
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.MultiGraphNodeProto"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt
index 8c4727cf35..b3adc50c7e 100644
--- a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt
@@ -1,84 +1,21 @@
path: "tensorflow.profiler.OpLogProto.IdToStringEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.profiler.tfprof_log_pb2.IdToStringEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "IdToStringEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt
index 1071a82b5c..7510c566ba 100644
--- a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt
@@ -1,88 +1,38 @@
path: "tensorflow.profiler.OpLogProto"
-tf_class {
- is_instance: "<class \'tensorflow.core.profiler.tfprof_log_pb2.OpLogProto\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "ID_TO_STRING_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "IdToStringEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "LOG_ENTRIES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "OpLogProto"
+ field {
+ name: "log_entries"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.OpLogEntry"
+ }
+ field {
+ name: "id_to_string"
+ number: 2
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.tfprof.OpLogProto.IdToStringEntry"
+ }
+ nested_type {
+ name: "IdToStringEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ options {
+ map_entry: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt
index ab3449d80f..eb99d0f533 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt
@@ -1,112 +1,74 @@
path: "tensorflow.summary.Event"
-tf_class {
- is_instance: "<class \'tensorflow.core.util.event_pb2.Event\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FILE_VERSION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GRAPH_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "LOG_MESSAGE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "META_GRAPH_DEF_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SESSION_LOG_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STEP_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SUMMARY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TAGGED_RUN_METADATA_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "WALL_TIME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Event"
+ field {
+ name: "wall_time"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_DOUBLE
+ }
+ field {
+ name: "step"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "file_version"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ oneof_index: 0
+ }
+ field {
+ name: "graph_def"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ field {
+ name: "summary"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary"
+ oneof_index: 0
+ }
+ field {
+ name: "log_message"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.LogMessage"
+ oneof_index: 0
+ }
+ field {
+ name: "session_log"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SessionLog"
+ oneof_index: 0
+ }
+ field {
+ name: "tagged_run_metadata"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TaggedRunMetadata"
+ oneof_index: 0
+ }
+ field {
+ name: "meta_graph_def"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ oneof_decl {
+ name: "what"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt
index 92ca4872ca..73de73869c 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt
@@ -1,108 +1,44 @@
path: "tensorflow.summary.SessionLog"
-tf_class {
- is_instance: "<class \'tensorflow.core.util.event_pb2.SessionLog\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CHECKPOINT"
- mtype: "<type \'int\'>"
- }
- member {
- name: "CHECKPOINT_PATH_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "MSG_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "START"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STATUS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STATUS_UNSPECIFIED"
- mtype: "<type \'int\'>"
- }
- member {
- name: "STOP"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SessionStatus"
- mtype: "<class \'google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "SessionLog"
+ field {
+ name: "status"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.SessionLog.SessionStatus"
+ }
+ field {
+ name: "checkpoint_path"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "msg"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ enum_type {
+ name: "SessionStatus"
+ value {
+ name: "STATUS_UNSPECIFIED"
+ number: 0
+ }
+ value {
+ name: "START"
+ number: 1
+ }
+ value {
+ name: "STOP"
+ number: 2
+ }
+ value {
+ name: "CHECKPOINT"
+ number: 3
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt
index f93da2196a..4a8b59cf02 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt
@@ -1,80 +1,12 @@
path: "tensorflow.summary.SummaryDescription"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.SummaryDescription\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "TYPE_HINT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "SummaryDescription"
+ field {
+ name: "type_hint"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt
index 605e305e82..8b271cf58f 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt
@@ -1,96 +1,36 @@
path: "tensorflow.summary.Summary.Audio"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.Audio\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CONTENT_TYPE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "ENCODED_AUDIO_STRING_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "LENGTH_FRAMES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NUM_CHANNELS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SAMPLE_RATE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Audio"
+ field {
+ name: "sample_rate"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ }
+ field {
+ name: "num_channels"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "length_frames"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "encoded_audio_string"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
+ field {
+ name: "content_type"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt
index 0646972196..dbbc02dd05 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt
@@ -1,92 +1,30 @@
path: "tensorflow.summary.Summary.Image"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.Image\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "COLORSPACE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "ENCODED_IMAGE_STRING_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "HEIGHT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "WIDTH_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Image"
+ field {
+ name: "height"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "width"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "colorspace"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "encoded_image_string"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt
index b319cd03d9..4176171cd9 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt
@@ -1,112 +1,74 @@
path: "tensorflow.summary.Summary.Value"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.Value\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "AUDIO_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "HISTO_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "IMAGE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "METADATA_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "NODE_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "OBSOLETE_OLD_STYLE_HISTOGRAM_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SIMPLE_VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TAG_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TENSOR_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Value"
+ field {
+ name: "node_name"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "tag"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "metadata"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SummaryMetadata"
+ }
+ field {
+ name: "simple_value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ oneof_index: 0
+ }
+ field {
+ name: "obsolete_old_style_histogram"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ field {
+ name: "image"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Image"
+ oneof_index: 0
+ }
+ field {
+ name: "histo"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.HistogramProto"
+ oneof_index: 0
+ }
+ field {
+ name: "audio"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Audio"
+ oneof_index: 0
+ }
+ field {
+ name: "tensor"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorProto"
+ oneof_index: 0
+ }
+ oneof_decl {
+ name: "value"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt
index 132ef1b7d2..d6c5e3a87a 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt
@@ -1,92 +1,144 @@
path: "tensorflow.summary.Summary"
-tf_class {
- is_instance: "<class \'tensorflow.core.framework.summary_pb2.Summary\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "Audio"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "Image"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "Value"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Summary"
+ field {
+ name: "value"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Value"
+ }
+ nested_type {
+ name: "Image"
+ field {
+ name: "height"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "width"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "colorspace"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "encoded_image_string"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
+ }
+ nested_type {
+ name: "Audio"
+ field {
+ name: "sample_rate"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ }
+ field {
+ name: "num_channels"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "length_frames"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT64
+ }
+ field {
+ name: "encoded_audio_string"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
+ field {
+ name: "content_type"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ }
+ nested_type {
+ name: "Value"
+ field {
+ name: "node_name"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "tag"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "metadata"
+ number: 9
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.SummaryMetadata"
+ }
+ field {
+ name: "simple_value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ oneof_index: 0
+ }
+ field {
+ name: "obsolete_old_style_histogram"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ oneof_index: 0
+ }
+ field {
+ name: "image"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Image"
+ oneof_index: 0
+ }
+ field {
+ name: "histo"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.HistogramProto"
+ oneof_index: 0
+ }
+ field {
+ name: "audio"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Summary.Audio"
+ oneof_index: 0
+ }
+ field {
+ name: "tensor"
+ number: 8
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.TensorProto"
+ oneof_index: 0
+ }
+ oneof_decl {
+ name: "value"
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt b/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt
index 4dce20819d..27c8873320 100644
--- a/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt
@@ -1,84 +1,18 @@
path: "tensorflow.summary.TaggedRunMetadata"
-tf_class {
- is_instance: "<class \'tensorflow.core.util.event_pb2.TaggedRunMetadata\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "RUN_METADATA_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TAG_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "TaggedRunMetadata"
+ field {
+ name: "tag"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "run_metadata"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BYTES
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt
index 8cf52b817f..87e4f160e5 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt
@@ -1,80 +1,12 @@
path: "tensorflow.train.BytesList"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.BytesList\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "BytesList"
+ field {
+ name: "value"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_BYTES
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
index 93ff856b09..f9de26839f 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt
@@ -1,80 +1,13 @@
path: "tensorflow.train.ClusterDef"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.ClusterDef\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "JOB_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "ClusterDef"
+ field {
+ name: "job"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.JobDef"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt
index f7215a2037..23c30f1ef4 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt
@@ -1,80 +1,13 @@
path: "tensorflow.train.Example"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.example_pb2.Example\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FEATURES_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Example"
+ field {
+ name: "features"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Features"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt
index 3ad98354d6..2a8b3714fc 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt
@@ -1,80 +1,13 @@
path: "tensorflow.train.FeatureList"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.FeatureList\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FEATURE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "FeatureList"
+ field {
+ name: "feature"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Feature"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt
index cd171f4ca3..cd1d56e606 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt
@@ -1,84 +1,22 @@
path: "tensorflow.train.FeatureLists.FeatureListEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.FeatureListEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "FeatureListEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.FeatureList"
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt
index 3d95017d58..3c183a6476 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt
@@ -1,84 +1,32 @@
path: "tensorflow.train.FeatureLists"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.FeatureLists\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FEATURE_LIST_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "FeatureListEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "FeatureLists"
+ field {
+ name: "feature_list"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.FeatureLists.FeatureListEntry"
+ }
+ nested_type {
+ name: "FeatureListEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.FeatureList"
+ }
+ options {
+ map_entry: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt
index 9cca132bba..5d0eb871c2 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt
@@ -1,88 +1,33 @@
path: "tensorflow.train.Feature"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.Feature\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "BYTES_LIST_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FLOAT_LIST_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "INT64_LIST_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Feature"
+ field {
+ name: "bytes_list"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.BytesList"
+ oneof_index: 0
+ }
+ field {
+ name: "float_list"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.FloatList"
+ oneof_index: 0
+ }
+ field {
+ name: "int64_list"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Int64List"
+ oneof_index: 0
+ }
+ oneof_decl {
+ name: "kind"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt
index 858aee0341..f912005f1c 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt
@@ -1,84 +1,22 @@
path: "tensorflow.train.Features.FeatureEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.FeatureEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "FeatureEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Feature"
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt
index 49cd12153b..b788ca1d57 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt
@@ -1,84 +1,32 @@
path: "tensorflow.train.Features"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.Features\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FEATURE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "FeatureEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Features"
+ field {
+ name: "feature"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Features.FeatureEntry"
+ }
+ nested_type {
+ name: "FeatureEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Feature"
+ }
+ options {
+ map_entry: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt
index e3f01334b5..55d3b46f20 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt
@@ -1,80 +1,15 @@
path: "tensorflow.train.FloatList"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.FloatList\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "FloatList"
+ field {
+ name: "value"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_FLOAT
+ options {
+ packed: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt
index 8917dc122c..1de92b3ab7 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt
@@ -1,80 +1,15 @@
path: "tensorflow.train.Int64List"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.feature_pb2.Int64List\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "Int64List"
+ field {
+ name: "value"
+ number: 1
+ label: LABEL_REPEATED
+ type: TYPE_INT64
+ options {
+ packed: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
index ac6d81541a..58115590a5 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt
@@ -1,84 +1,21 @@
path: "tensorflow.train.JobDef.TasksEntry"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.TasksEntry\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "KEY_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VALUE_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "TasksEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ options {
+ map_entry: true
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
index ce34537fa1..d7eb505e27 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt
@@ -1,88 +1,37 @@
path: "tensorflow.train.JobDef"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.cluster_pb2.JobDef\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TASKS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TasksEntry"
- mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "JobDef"
+ field {
+ name: "name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "tasks"
+ number: 2
+ label: LABEL_REPEATED
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.JobDef.TasksEntry"
+ }
+ nested_type {
+ name: "TasksEntry"
+ field {
+ name: "key"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "value"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ options {
+ map_entry: true
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt
index 84498a64f5..4ec99469e4 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt
@@ -1,120 +1,64 @@
path: "tensorflow.train.SaverDef"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.saver_pb2.SaverDef\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CheckpointFormatVersion"
- mtype: "<class \'google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FILENAME_TENSOR_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "KEEP_CHECKPOINT_EVERY_N_HOURS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "LEGACY"
- mtype: "<type \'int\'>"
- }
- member {
- name: "MAX_TO_KEEP_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "RESTORE_OP_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SAVE_TENSOR_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "SHARDED_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "V1"
- mtype: "<type \'int\'>"
- }
- member {
- name: "V2"
- mtype: "<type \'int\'>"
- }
- member {
- name: "VERSION_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "SaverDef"
+ field {
+ name: "filename_tensor_name"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "save_tensor_name"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "restore_op_name"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "max_to_keep"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "sharded"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "keep_checkpoint_every_n_hours"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_FLOAT
+ }
+ field {
+ name: "version"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_ENUM
+ type_name: ".tensorflow.SaverDef.CheckpointFormatVersion"
+ }
+ enum_type {
+ name: "CheckpointFormatVersion"
+ value {
+ name: "LEGACY"
+ number: 0
+ }
+ value {
+ name: "V1"
+ number: 1
+ }
+ value {
+ name: "V2"
+ number: 2
+ }
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt
index 9ab9553702..6a4553bbc1 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt
@@ -1,84 +1,20 @@
path: "tensorflow.train.SequenceExample"
-tf_class {
- is_instance: "<class \'tensorflow.core.example.example_pb2.SequenceExample\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CONTEXT_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "FEATURE_LISTS_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "SequenceExample"
+ field {
+ name: "context"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.Features"
+ }
+ field {
+ name: "feature_lists"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.FeatureLists"
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt
index af0a3b73cc..83ee7b3eb9 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt
@@ -1,96 +1,38 @@
path: "tensorflow.train.ServerDef"
-tf_class {
- is_instance: "<class \'tensorflow.core.protobuf.tensorflow_server_pb2.ServerDef\'>"
- is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
- member {
- name: "CLUSTER_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DEFAULT_SESSION_CONFIG_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "DESCRIPTOR"
- mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
- }
- member {
- name: "Extensions"
- mtype: "<type \'getset_descriptor\'>"
- }
- member {
- name: "JOB_NAME_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "PROTOCOL_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member {
- name: "TASK_INDEX_FIELD_NUMBER"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "ByteSize"
- }
- member_method {
- name: "Clear"
- }
- member_method {
- name: "ClearExtension"
- }
- member_method {
- name: "ClearField"
- }
- member_method {
- name: "CopyFrom"
- }
- member_method {
- name: "DiscardUnknownFields"
- }
- member_method {
- name: "FindInitializationErrors"
- }
- member_method {
- name: "FromString"
- }
- member_method {
- name: "HasExtension"
- }
- member_method {
- name: "HasField"
- }
- member_method {
- name: "IsInitialized"
- }
- member_method {
- name: "ListFields"
- }
- member_method {
- name: "MergeFrom"
- }
- member_method {
- name: "MergeFromString"
- }
- member_method {
- name: "ParseFromString"
- }
- member_method {
- name: "RegisterExtension"
- }
- member_method {
- name: "SerializePartialToString"
- }
- member_method {
- name: "SerializeToString"
- }
- member_method {
- name: "SetInParent"
- }
- member_method {
- name: "WhichOneof"
- }
- member_method {
- name: "__init__"
+tf_proto {
+ descriptor {
+ name: "ServerDef"
+ field {
+ name: "cluster"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.ClusterDef"
+ }
+ field {
+ name: "job_name"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ field {
+ name: "task_index"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
+ field {
+ name: "default_session_config"
+ number: 4
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: ".tensorflow.ConfigProto"
+ }
+ field {
+ name: "protocol"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt
index cc31bb4e4b..448764fe08 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\'], "
+ argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], "
}
member_method {
name: "prepare_session"
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt
index 1f0e59a1ac..9677e5a98e 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt
@@ -104,7 +104,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\', \'None\'], "
}
member_method {
name: "loop"
diff --git a/tensorflow/tools/api/lib/api_objects.proto b/tensorflow/tools/api/lib/api_objects.proto
index 0966a5f1d5..7dcde0bbc3 100644
--- a/tensorflow/tools/api/lib/api_objects.proto
+++ b/tensorflow/tools/api/lib/api_objects.proto
@@ -1,5 +1,7 @@
syntax = "proto2";
+import "google/protobuf/descriptor.proto";
+
package third_party.tensorflow.tools.api;
message TFAPIMember {
@@ -24,8 +26,13 @@ message TFAPIClass {
repeated TFAPIMethod member_method = 3;
};
+message TFAPIProto {
+ optional google.protobuf.DescriptorProto descriptor = 1;
+};
+
message TFAPIObject {
optional string path = 1;
optional TFAPIModule tf_module = 2;
optional TFAPIClass tf_class = 3;
+ optional TFAPIProto tf_proto = 4;
};
diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
index 0b30f7b4d1..1cf330e702 100644
--- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
+++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from google.protobuf import message
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -101,6 +102,11 @@ def _SanitizedMRO(obj):
return return_list
+def _IsProtoClass(obj):
+ """Returns whether the passed obj is a Protocol Buffer class."""
+ return isinstance(obj, type) and issubclass(obj, message.Message)
+
+
class PythonObjectToProtoVisitor(object):
"""A visitor that summarizes given python objects as protobufs."""
@@ -153,6 +159,13 @@ class PythonObjectToProtoVisitor(object):
# Store the constructed module object.
self._protos[lib_path] = api_objects_pb2.TFAPIObject(
path=lib_path, tf_module=module_obj)
+ elif _IsProtoClass(parent):
+ proto_obj = api_objects_pb2.TFAPIProto()
+ parent.DESCRIPTOR.CopyToProto(proto_obj.descriptor)
+
+ # Store the constructed proto object.
+ self._protos[lib_path] = api_objects_pb2.TFAPIObject(
+ path=lib_path, tf_proto=proto_obj)
elif tf_inspect.isclass(parent):
# Construct a class.
class_obj = api_objects_pb2.TFAPIClass()
@@ -161,7 +174,7 @@ class PythonObjectToProtoVisitor(object):
if name in parent_corner_cases:
# If we have an empty entry, skip this object.
if parent_corner_cases[name]:
- module_obj.member.add(**(parent_corner_cases[name]))
+ class_obj.member.add(**(parent_corner_cases[name]))
else:
_AddMember(name, child, class_obj)
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 1ad6b6d1c0..90375a794f 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -35,6 +35,7 @@ import unittest
import tensorflow as tf
+from google.protobuf import message
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
@@ -195,6 +196,25 @@ class ApiCompatibilityTest(test.TestCase):
else:
logging.info('No differences found between API and golden.')
+ def testNoSubclassOfMessage(self):
+
+ def Visit(path, parent, unused_children):
+ """A Visitor that crashes on subclasses of generated proto classes."""
+ # If the traversed object is a proto Message class
+ if not (isinstance(parent, type) and
+ issubclass(parent, message.Message)):
+ return
+ if parent is message.Message:
+ return
+ # Check that it is a direct subclass of Message.
+ if message.Message not in parent.__bases__:
+ raise NotImplementedError(
+ 'Object tf.%s is a subclass of a generated proto Message. '
+ 'They are not yet supported by the API tools.' % path)
+ visitor = public_api.PublicAPIVisitor(Visit)
+ visitor.do_not_descend_map['tf'].append('contrib')
+ traverse.traverse(tf, visitor)
+
@unittest.skipUnless(
sys.version_info.major == 2,
'API compabitility test goldens are generated using python2.')
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index eeb1fab40c..de93b12b97 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -667,12 +667,12 @@ int Main(int argc, char** argv) {
output_prefix, benchmark_name, "meta-init-plus-first-inference", 1,
initialization_time_s + (warmup_time_us / 1000000.0) / warmup_runs);
- std::map<string, int64> node_type_map_count;
- std::map<string, int64> node_type_map_time;
- std::map<string, int64> node_type_map_memory;
- std::map<string, int64> node_type_map_times_called;
+ std::map<std::string, int64_t> node_type_map_count;
+ std::map<std::string, int64_t> node_type_map_time;
+ std::map<std::string, int64_t> node_type_map_memory;
+ std::map<std::string, int64_t> node_type_map_times_called;
- int64 accumulated_us;
+ int64_t accumulated_us;
stats->ComputeStatsByType(&node_type_map_count, &node_type_map_time,
&node_type_map_memory,
&node_type_map_times_called, &accumulated_us);
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
index 51e10f81f8..8eeddcdb82 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
@@ -34,5 +34,5 @@ yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc,java -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \
- --test_output=errors -- \
+ --test_output=errors --test_size_filters=small,medium -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
index ea14848b1a..8eca1987f0 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
@@ -33,5 +33,5 @@ yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \
- --test_output=errors -- \
+ --test_output=errors --test_size_filters=small,medium -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
index c798081250..2b68de3c5b 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
@@ -33,7 +33,7 @@ yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \
- --test_output=errors -- \
+ --test_size_filters=small,medium --test_output=errors -- \
//tensorflow/contrib/... \
-//tensorflow/contrib/lite/... \
//tensorflow/contrib/lite:context_test \
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
index a9accb9dd5..51eb2cd7e6 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
@@ -33,5 +33,5 @@ yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \
- --test_output=errors -- \
+ --test_output=errors --test_size_filters=small,medium -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
index 02224d8e9d..9d2c8383fa 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
@@ -37,5 +37,6 @@ yes "" | $PYTHON_BIN_PATH configure.py
bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
--test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
--build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \
+ --test_size_filters=small,medium \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
index 0367a53d14..5b3383e105 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
@@ -37,5 +37,6 @@ yes "" | $PYTHON_BIN_PATH configure.py
bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
--test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
--build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \
+ --test_size_filters=small,medium \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index a2300811bb..73520bb2ac 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -54,18 +54,24 @@ function cleanup {
trap cleanup EXIT
skip_test=0
+release_build=0
for ARG in "$@"; do
if [[ "$ARG" == --skip_test ]]; then
skip_test=1
elif [[ "$ARG" == --enable_gcs_remote_cache ]]; then
set_gcs_remote_cache_options
+ elif [[ "$ARG" == --release_build ]]; then
+ release_build=1
fi
done
-# --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
-# by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521
-echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}"
+if [[ "$release_build" != 1 ]]; then
+ # --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
+ # by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521
+ # Because this hurts the performance of TF, we don't enable it in release build.
+ echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}"
+fi
echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc
diff --git a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh
index a410c10b61..d085e21b03 100755
--- a/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh
+++ b/tensorflow/tools/ci_build/xla/linux/gpu/run_py3.sh
@@ -37,6 +37,7 @@ bazel clean
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test,-no_oss -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --test_size_filters=small,medium \
--build_tests_only --test_output=errors --local_test_jobs=8 \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
--config=xla -- \
diff --git a/tensorflow/tools/dist_test/build_server.sh b/tensorflow/tools/dist_test/build_server.sh
index 225c034741..345217d733 100755
--- a/tensorflow/tools/dist_test/build_server.sh
+++ b/tensorflow/tools/dist_test/build_server.sh
@@ -23,7 +23,7 @@
# E.g.: tensorflow/tf_grpc_test_server:0.11.0rc1
#
# whl_file_location: URL from which the TensorFlow whl file will be downloaded.
-# E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl
+# E.g.: https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0-cp27-none-linux_x86_64.whl
# E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl
#
# The optional flag --test lets the script to use the Dockerfile for the
diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh
index 99e09502be..b0114721bd 100755
--- a/tensorflow/tools/dist_test/local_test.sh
+++ b/tensorflow/tools/dist_test/local_test.sh
@@ -35,7 +35,7 @@
#
# Arguments:
# whl_file_location: URL from which the TensorFlow whl file will be acquired.
-# E.g.: https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl
+# E.g.: https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.5.0-cp27-none-linux_x86_64.whl
# E.g.: /path/to/folder/tensorflow-0.11.0rc1-cp27-none-linux_x86_64.whl
#
# --leave_container_running: Do not stop the docker-in-docker container after
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 111d54d820..853ec6194f 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -50,7 +50,11 @@ def _is_free_function(py_object, full_name, index):
return True
-def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'):
+def write_docs(output_dir,
+ parser_config,
+ yaml_toc,
+ root_title='TensorFlow',
+ search_hints=True):
"""Write previously extracted docs to disk.
Write a docs page for each symbol included in the indices of parser_config to
@@ -66,6 +70,8 @@ def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'):
indices.
yaml_toc: Set to `True` to generate a "_toc.yaml" file.
root_title: The title name for the root level index.md.
+ search_hints: (bool) include meta-data search hints at the top of each
+ output file.
Raises:
ValueError: if `output_dir` is not an absolute path
@@ -134,7 +140,13 @@ def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow'):
if not os.path.exists(directory):
os.makedirs(directory)
# This function returns raw bytes in PY2 or unicode in PY3.
- text = pretty_docs.build_md_page(page_info)
+ if search_hints:
+ content = [page_info.get_metadata_html()]
+ else:
+ content = ['']
+
+ content.append(pretty_docs.build_md_page(page_info))
+ text = '\n'.join(content)
if six.PY3:
text = text.encode('utf-8')
with open(path, 'wb') as f:
@@ -467,6 +479,12 @@ class DocGenerator(object):
self._do_not_descend_map = _get_default_do_not_descend_map()
self.yaml_toc = True
+ self.argument_parser.add_argument(
+ '--no_search_hints',
+ dest='search_hints',
+ action='store_false',
+ default=True)
+
def add_output_dir_argument(self):
self.argument_parser.add_argument(
'--output_dir',
@@ -553,7 +571,8 @@ class DocGenerator(object):
output_dir,
parser_config,
yaml_toc=self.yaml_toc,
- root_title=root_title)
+ root_title=root_title,
+ search_hints=getattr(flags, 'search_hints', True))
_other_docs(flags.src_dir, flags.output_dir, reference_resolver)
parser_config.reference_resolver.log_errors()
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index fb0bd2c2ff..50c9052741 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import ast
import collections
import functools
+import itertools
import json
import os
import re
@@ -614,6 +615,9 @@ def _parse_md_docstring(py_object, relative_path_to_root, reference_resolver):
docstring, compatibility = _handle_compatibility(raw_docstring)
docstring, function_details = _parse_function_details(docstring)
+ if 'Generated by: tensorflow/tools/api/generator' in docstring:
+ docstring = ''
+
return _DocstringInfo(
docstring.split('\n')[0], docstring, function_details, compatibility)
@@ -906,6 +910,9 @@ class _FunctionPageInfo(object):
def add_decorator(self, dec):
self._decorators.append(dec)
+ def get_metadata_html(self):
+ return _Metadata(self.full_name).build_html()
+
class _ClassPageInfo(object):
"""Collects docs for a class page.
@@ -1099,6 +1106,14 @@ class _ClassPageInfo(object):
"""Returns a list of `_LinkInfo` pointing to any nested classes."""
return self._classes
+ def get_metadata_html(self):
+ meta_data = _Metadata(self.full_name)
+ for item in itertools.chain(self.classes, self.properties, self.methods,
+ self.other_members):
+ meta_data.append(item)
+
+ return meta_data.build_html()
+
def _add_class(self, short_name, full_name, obj, doc, url):
"""Adds a `_LinkInfo` for a nested class to `classes` list.
@@ -1330,6 +1345,16 @@ class _ModulePageInfo(object):
self._other_members.append(
_OtherMemberInfo(short_name, full_name, obj, doc))
+ def get_metadata_html(self):
+ meta_data = _Metadata(self.full_name)
+
+ # Objects with their own pages are not added to the matadata list for the
+ # module, the module only has a link to the object page. No docs.
+ for item in self.other_members:
+ meta_data.append(item)
+
+ return meta_data.build_html()
+
def collect_docs_for_module(self, parser_config):
"""Collect information necessary specifically for a module's doc page.
@@ -1575,7 +1600,8 @@ class _GeneratedFile(object):
return True
def __str__(self):
- return 'Defined in `%s%s`.\n\n' % (self.path_prefix, self.path)
+ return 'Defined in generated file: `%s%s`.\n\n' % (self.path_prefix,
+ self.path)
def _get_defined_in(py_object, parser_config):
@@ -1612,6 +1638,8 @@ def _get_defined_in(py_object, parser_config):
if re.match(r'.*/gen_[^/]*\.py$', path):
return _GeneratedFile(path, parser_config)
+ if 'genfiles' in path or 'tools/api/generator' in path:
+ return _GeneratedFile(path, parser_config)
elif re.match(r'.*_pb2\.py$', path):
# The _pb2.py files all appear right next to their defining .proto file.
return _ProtoFile(path[:-7] + '.proto', parser_config)
@@ -1656,3 +1684,41 @@ def generate_global_index(library_name, index, reference_resolver):
# TODO(markdaoust): use a _ModulePageInfo -> prety_docs.build_md_page()
return '\n'.join(lines)
+
+
+class _Metadata(object):
+ """A class for building a page's Metadata block.
+
+ Attributes:
+ name: The name of the page being described by the Metadata block.
+ """
+
+ def __init__(self, name):
+ """Creates a Metadata builder.
+
+ Args:
+ name: The name of the page being described by the Metadata block.
+ """
+ self.name = name
+ self._content = []
+
+ def append(self, item):
+ """Adds an item from the page to the Metadata block.
+
+ Args:
+ item: The parsed page section to add.
+ """
+ self._content.append(item.short_name)
+
+ def build_html(self):
+ """Returns the Metadata block as an Html string."""
+ schema = 'http://developers.google.com/ReferenceObject'
+ parts = ['<div itemscope itemtype="%s">' % schema]
+
+ parts.append('<meta itemprop="name" content="%s" />' % self.name)
+ for item in self._content:
+ parts.append('<meta itemprop="property" content="%s"/>' % item)
+
+ parts.extend(['</div>', ''])
+
+ return '\n'.join(parts)
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index 55ab5bdd49..63d4fef91c 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -27,7 +27,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import itertools
import textwrap
@@ -58,8 +57,7 @@ def build_md_page(page_info):
def _build_function_page(page_info):
"""Given a FunctionPageInfo object Return the page as an md string."""
- parts = [_Metadata(page_info.full_name).build_html()]
- parts.append('# %s\n\n' % page_info.full_name)
+ parts = ['# %s\n\n' % page_info.full_name]
if len(page_info.aliases) > 1:
parts.append('### Aliases:\n\n')
@@ -83,17 +81,7 @@ def _build_function_page(page_info):
def _build_class_page(page_info):
"""Given a ClassPageInfo object Return the page as an md string."""
- meta_data = _Metadata(page_info.full_name)
- for item in itertools.chain(
- page_info.classes,
- page_info.properties,
- page_info.methods,
- page_info.other_members):
- meta_data.append(item)
-
- parts = [meta_data.build_html()]
-
- parts.append('# {page_info.full_name}\n\n'.format(page_info=page_info))
+ parts = ['# {page_info.full_name}\n\n'.format(page_info=page_info)]
parts.append('## Class `%s`\n\n' % page_info.full_name.split('.')[-1])
if page_info.bases:
@@ -186,17 +174,7 @@ def _build_class_page(page_info):
def _build_module_page(page_info):
"""Given a ClassPageInfo object Return the page as an md string."""
- meta_data = _Metadata(page_info.full_name)
-
- # Objects with their own pages are not added to the matadata list for the
- # module, as the only thing on the module page is a link to the object's page.
- for item in page_info.other_members:
- meta_data.append(item)
-
- parts = [meta_data.build_html()]
-
- parts.append(
- '# Module: {full_name}\n\n'.format(full_name=page_info.full_name))
+ parts = ['# Module: {full_name}\n\n'.format(full_name=page_info.full_name)]
if len(page_info.aliases) > 1:
parts.append('### Aliases:\n\n')
@@ -317,41 +295,3 @@ def _build_function_details(function_details):
parts.append(''.join(sub))
return '\n'.join(parts)
-
-
-class _Metadata(object):
- """A class for building a page's Metadata block.
-
- Attributes:
- name: The name of the page being described by the Metadata block.
- """
-
- def __init__(self, name):
- """Create a Metadata builder.
-
- Args:
- name: The name of the page being described by the Metadata block.
- """
- self.name = name
- self._content = []
-
- def append(self, item):
- """Add an item from the page to the Metadata block.
-
- Args:
- item: The parsed page section to add.
- """
- self._content.append(item.short_name)
-
- def build_html(self):
- """Return the Metadata block as an Html string."""
- schema = 'http://developers.google.com/ReferenceObject'
- parts = ['<div itemscope itemtype="%s">' % schema]
-
- parts.append('<meta itemprop="name" content="%s" />' % self.name)
- for item in self._content:
- parts.append('<meta itemprop="property" content="%s"/>' % item)
-
- parts.extend(['</div>', '', ''])
-
- return '\n'.join(parts)
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
index 7651a03fe5..435f46c107 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
@@ -191,7 +191,7 @@ class FoldOldBatchNormsTest : public ::testing::Test {
std::vector<Tensor> fused_outputs;
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
- test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
for (const NodeDef& node : fused_graph_def.node()) {
EXPECT_NE("FusedBatchNorm", node.op());
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 286459d01c..2dc73ca7be 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -201,7 +201,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
urls = [
"https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2",
"http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2",
- "http://www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2",
],
sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324",
strip_prefix = "nasm-2.12.02",
@@ -453,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/638915a37f90f26599941977846408864f70ab35.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/638915a37f90f26599941977846408864f70ab35.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001.tar.gz",
],
- sha256 = "aae3cacefa318cef030b4ca1e81ee9906752bbd89013cf9d47e156b5ad04b3a5",
- strip_prefix = "llvm-638915a37f90f26599941977846408864f70ab35",
+ sha256 = "03db53e502dd4fbdbbf1c470776315eeff665180ade32859cfb6c1e996bbf2a5",
+ strip_prefix = "llvm-d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)
diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD
index 017613abb0..a058c46cc4 100644
--- a/third_party/mkl/BUILD
+++ b/third_party/mkl/BUILD
@@ -34,7 +34,7 @@ filegroup(
"@org_tensorflow//tensorflow:windows": [
"@mkl_windows//:LICENSE",
],
- "//conditions:default": []
+ "//conditions:default": [],
}),
visibility = ["//visibility:public"],
)
@@ -55,6 +55,6 @@ cc_library(
"@mkl_windows//:mkl_headers",
"@mkl_windows//:mkl_libs_windows",
],
- "//conditions:default": []
+ "//conditions:default": [],
}),
)
diff --git a/util/python/BUILD b/third_party/python_runtime/BUILD
index f5fa0c6d29..2a1609191f 100644
--- a/util/python/BUILD
+++ b/third_party/python_runtime/BUILD
@@ -3,6 +3,6 @@ licenses(["notice"]) # New BSD, Python Software Foundation
package(default_visibility = ["//visibility:public"])
alias(
- name = "python_headers",
+ name = "headers",
actual = "@local_config_python//:python_headers",
)